From 5cd3c0fb3496f7a0d15ad765e041a40ac4e85866 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 12:00:57 +0800 Subject: [PATCH 01/60] docs: add transformers resume design spec --- ...7-transformers-checkpoint-resume-design.md | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md new file mode 100644 index 00000000..a821402a --- /dev/null +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -0,0 +1,353 @@ +# Transformers Strict Resume Design + +## Summary + +This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. + +The implementation aligns the resume semantics of `TransformersModel` with the existing `MegatronModel` behavior: + +- normal weight loading remains available +- strict resume restores model weights and training state together +- strict resume does not silently fall back to weight-only loading when required state is missing + +Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. + +## Goals + +- Support true checkpoint resume for `TransformersModel` +- Restore model weights, optimizer state, scheduler state, RNG state, and step counters +- Support dataset progress skipping for map-style datasets +- Expose Swift-like resume controls without adding a new trainer class +- Preserve existing weight-only loading and saving behavior +- Keep backward compatibility for existing checkpoints where possible + +## Non-Goals + +- Do not introduce a new `Trainer` class or resume manager class +- Do not guarantee exact sample-by-sample replay when retry-based sampling changes sample order +- Do not support exact data-progress resume for `IterableDataset` or streaming datasets +- Do not attempt to persist transient runtime state such as in-flight batch tensors, current loss tensors, or metric caches + +## User-Facing Resume Controls + +Resume behavior is controlled by existing training entrypoints through three new parameters: + +- `resume_from_checkpoint: Optional[str] = None` +- `resume_only_model: bool = False` +- `ignore_data_skip: bool = False` + +### Parameter semantics + +#### `resume_from_checkpoint` + +- Specifies the checkpoint directory or checkpoint path to resume from +- When unset, training starts normally from scratch +- When set, the training entrypoint reads the checkpoint and restores model state through existing model APIs + +#### `resume_only_model` + +- Defaults to `False` +- When `False`, resume restores full training state +- When `True`, resume restores only model weights + +#### `ignore_data_skip` + +- Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` +- Defaults to `False` +- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, or RNG +- When `True`, the system restores only model weights and does not restore training progress or skip consumed data + +### Effective behavior matrix + +#### Case 1: `resume_from_checkpoint is None` + +- Start a new training run + +#### Case 2: `resume_from_checkpoint is not None` and `resume_only_model=False` + +- Restore model weights +- Restore optimizer state +- Restore scheduler state +- Restore RNG state +- Restore step counters +- Attempt to skip already consumed training data +- If required model training state is missing, fail without fallback + +#### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` + +- Restore model weights only +- Do not restore optimizer, scheduler, or RNG +- Restore step/progress metadata needed for data skipping +- Attempt to skip already consumed training data + +#### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` + +- Restore model weights only +- Do not restore optimizer, scheduler, RNG, step counters, or data progress +- Restart the training loop from step 0 with no skipping + +## Checkpoint Layout + +Existing checkpoint layout remains valid. New resume metadata is added alongside current files. + +### Existing files preserved + +- model weights saved by `save_pretrained` +- LoRA weights saved as `adapter_model.safetensors` +- tokenizer artifacts +- `optimizer.pt` +- `scheduler.pt` + +### New file + +- `training_state.pt` + +### `training_state.pt` contents + +`training_state.pt` stores a small dictionary with the following fields: + +- `checkpoint_version` +- `cur_step` +- `gradient_accumulation_steps` +- `scaler_state_dict` +- `scaler_has_nan` +- `rng_state` +- `data_progress` + +### `rng_state` contents + +- Python `random` state +- NumPy RNG state +- PyTorch CPU RNG state +- CUDA RNG state + +### `data_progress` contents + +First version stores progress in a compact form: + +- `consumed_train_samples` +- optionally `consumed_batches` when this is easier to compute reliably in a given entrypoint + +The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. + +## Model Save and Load Semantics + +## `TransformersModel.save` + +`TransformersModel.save(..., save_optimizer=True)` is extended to: + +1. Save weights exactly as today +2. Save tokenizer exactly as today +3. Save `optimizer.pt` and `scheduler.pt` exactly as today +4. Save `training_state.pt` + +When `save_optimizer=False`, save remains weight-only and does not produce strict resume metadata. + +## `TransformersModel.load` + +`TransformersModel.load(..., load_optimizer=False)` keeps current behavior: + +- load model weights only + +`TransformersModel.load(..., load_optimizer=True)` becomes strict model-state resume: + +1. Resolve checkpoint directory +2. Load model weights +3. Load optimizer and scheduler state +4. Load `training_state.pt` +5. Restore scaler state +6. Restore RNG state +7. Restore `cur_step` and `gradient_accumulation_steps` + +### Failure behavior + +When `load_optimizer=True`, missing required model training state is an error: + +- missing `training_state.pt` -> fail +- missing `optimizer.pt` when optimizer restore is required -> fail +- missing `scheduler.pt` when scheduler restore is required -> fail +- malformed required fields in `training_state.pt` -> fail + +This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. + +This matches the `MegatronModel` contract more closely than the current `TransformersModel` behavior. + +## Training Progress and Data Skipping + +Twinkle does not currently have a central trainer abstraction. Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. + +## Dataloader extensions + +Existing dataloader and sampler code is extended rather than replaced: + +- `twinkle.dataloader.DataLoader` +- `twinkle.dataloader.DeviceMeshSampler` +- retry-aware sampler flow + +The dataloader gains resume-oriented arguments: + +- `skip_samples: int = 0` +- optionally `skip_batches: int = 0` + +Map-style datasets use this progress to skip already consumed data before yielding new training batches. + +## Map-style dataset behavior + +For datasets with `__len__`, Twinkle attempts to skip previously consumed data using sampler or batch-sampler level skipping. + +Preferred behavior: + +- preserve existing sharding logic +- apply skip before data is yielded to the training loop +- keep the solution compatible with current `DeviceMeshSampler` wrapping + +## Iterable and streaming behavior + +`IterableDataset` and streaming datasets do not support exact progress skipping in this design. + +Behavior for these datasets: + +- restore model state according to the selected resume mode +- log a clear warning that consumed-data skipping is not supported +- continue training without skipping historical samples + +This is the only fallback allowed in the design. It applies only to dataset progress skipping, not to model-state resume. + +## Entry Point Integration + +No new trainer class is introduced. + +Resume parameters are threaded through existing training entrypoints: + +- direct local training loops using `TwinkleModel` / `TransformersModel` +- current client/server training flows that already support checkpoint save and load + +The practical integration model is: + +1. Parse or receive the three resume parameters +2. If `resume_from_checkpoint` is unset, construct dataloader normally +3. If `resume_only_model=False`, call existing model load with strict restore semantics +4. If `resume_only_model=True`, call weight-only model load +5. If data skipping is enabled, read progress metadata from `training_state.pt` +6. Recreate the dataloader with skip arguments applied + +This keeps the training loop explicit and compatible with current Twinkle examples. + +## Server and Client Behavior + +Server-side checkpoint save/load behavior should preserve current APIs while adding richer metadata. + +### Save path + +When server-side save endpoints request optimizer save: + +- save the model checkpoint as today +- save `optimizer.pt`, `scheduler.pt`, and `training_state.pt` +- persist checkpoint metadata through the existing checkpoint manager + +### Load path + +Current `load_optimizer=True` behavior is retained as the trigger for strict model-state restore. + +The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: + +- call strict resume +- call weight-only resume +- request data skipping + +The underlying server model APIs do not need a new trainer object to support this. + +## Compatibility Strategy + +### Existing checkpoints + +Existing checkpoints remain loadable in weight-only mode. + +Examples: + +- `model.load(path, load_optimizer=False)` continues to work +- inference-only consumers remain unaffected + +### Old checkpoints under strict resume + +Old checkpoints that lack `training_state.pt` are not valid for strict `TransformersModel` resume. + +Expected behavior: + +- strict resume fails clearly +- weight-only load continues to work when requested explicitly + +### `resume_only_model=True` + +For `resume_only_model=True`, old checkpoints may still be usable if weight files are present. + +If data skipping is requested but no progress metadata exists, the entrypoint should fail clearly rather than silently train from the beginning while claiming resumed progress. + +## Risks and Constraints + +### RetrySampler interaction + +`RetrySampler` may retry or replace failed samples, including random backfill behavior at the tail of an epoch. + +Because of that: + +- progress skipping can preserve approximate data position +- exact sample-for-sample replay is not guaranteed when retry or backfill paths are exercised + +This limitation should be documented explicitly. + +### Dataset shape changes + +If dataset definition, slicing, filtering, or shuffle configuration changes between save and resume, data skipping semantics may become invalid. + +The user guidance should state that resume should be done with unchanged training parameters and unchanged dataset configuration. + +### Distributed consistency + +Skip logic must be compatible with current device-mesh sharding. The implementation should ensure skip is applied consistently before per-rank slicing causes divergence. + +## Testing Strategy + +Tests should cover: + +### Model-state save/load + +- `training_state.pt` is written when optimizer save is enabled +- scaler, RNG, `cur_step`, and accumulation settings are restored +- strict resume fails when required files are missing + +### Weight-only compatibility + +- legacy checkpoints still load in weight-only mode +- `resume_only_model=True` restores weights without optimizer and RNG + +### Data progress skipping + +- map-style datasets skip consumed data correctly +- skip behavior remains correct with device-mesh sharding +- iterable and streaming datasets emit warnings and continue without skipping + +### Failure cases + +- missing progress metadata when data skipping is requested +- malformed `training_state.pt` +- mismatch between requested strict resume and available checkpoint contents + +## Implementation Outline + +1. Extend `TransformersModel.save/load` to persist and restore `training_state.pt` +2. Add helper methods for RNG save/load and training-state serialization +3. Extend dataloader and sampler stack to support skip arguments for map-style datasets +4. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints +5. Add warnings for unsupported iterable/streaming data skipping +6. Update docs and examples to prefer trainer-level resume parameters over ad hoc `model.load(..., load_optimizer=True)` logic + +## User Guidance + +Recommended guidance text: + +- To resume training, keep other parameters unchanged and provide `resume_from_checkpoint` +- `resume_only_model=False` performs full resume +- `resume_only_model=True` restores only model weights +- `ignore_data_skip=True` disables progress restore and starts from step 0 +- Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data From 91eeaebb0077b6f2ea456fef1dc64a653ec8eaa8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:33:11 +0800 Subject: [PATCH 02/60] docs: refine transformers resume design spec --- ...7-transformers-checkpoint-resume-design.md | 200 +++++++++++------- 1 file changed, 125 insertions(+), 75 deletions(-) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index a821402a..9a1d45c2 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -4,22 +4,23 @@ This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. -The implementation aligns the resume semantics of `TransformersModel` with the existing `MegatronModel` behavior: +The design supports both full-parameter training and LoRA training: -- normal weight loading remains available -- strict resume restores model weights and training state together -- strict resume does not silently fall back to weight-only loading when required state is missing +- full-parameter training restores weights during model initialization +- LoRA training restores adapter weights through the existing load path +- both modes share the same training-state resume contract +- strict model-state resume does not silently fall back to weight-only loading when required state is missing Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. ## Goals - Support true checkpoint resume for `TransformersModel` -- Restore model weights, optimizer state, scheduler state, RNG state, and step counters +- Support both full-parameter and LoRA training resume +- Restore optimizer state, scheduler state, scaler state, RNG state, and step counters - Support dataset progress skipping for map-style datasets - Expose Swift-like resume controls without adding a new trainer class - Preserve existing weight-only loading and saving behavior -- Keep backward compatibility for existing checkpoints where possible ## Non-Goals @@ -54,7 +55,7 @@ Resume behavior is controlled by existing training entrypoints through three new - Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` - Defaults to `False` -- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, or RNG +- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, scaler, or RNG - When `True`, the system restores only model weights and does not restore training progress or skip consumed data ### Effective behavior matrix @@ -68,6 +69,7 @@ Resume behavior is controlled by existing training entrypoints through three new - Restore model weights - Restore optimizer state - Restore scheduler state +- Restore scaler state - Restore RNG state - Restore step counters - Attempt to skip already consumed training data @@ -76,102 +78,141 @@ Resume behavior is controlled by existing training entrypoints through three new #### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` - Restore model weights only -- Do not restore optimizer, scheduler, or RNG +- Do not restore optimizer, scheduler, scaler, or RNG - Restore step/progress metadata needed for data skipping - Attempt to skip already consumed training data #### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` - Restore model weights only -- Do not restore optimizer, scheduler, RNG, step counters, or data progress +- Do not restore optimizer, scheduler, scaler, RNG, step counters, or data progress - Restart the training loop from step 0 with no skipping ## Checkpoint Layout -Existing checkpoint layout remains valid. New resume metadata is added alongside current files. +Existing weight layouts remain valid. New training-state files are added alongside current checkpoint contents. ### Existing files preserved -- model weights saved by `save_pretrained` +- full-model weights saved by `save_pretrained` - LoRA weights saved as `adapter_model.safetensors` - tokenizer artifacts - `optimizer.pt` - `scheduler.pt` -### New file +### New training-state files -- `training_state.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -### `training_state.pt` contents +### `trainer_state.json` contents -`training_state.pt` stores a small dictionary with the following fields: +`trainer_state.json` stores lightweight training metadata: - `checkpoint_version` - `cur_step` - `gradient_accumulation_steps` -- `scaler_state_dict` -- `scaler_has_nan` -- `rng_state` -- `data_progress` +- `consumed_train_samples` +- optionally `consumed_batches` + +The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. + +### `scaler.pt` contents -### `rng_state` contents +- AMP scaler state dict +- optional scaler-related flags such as `scaler_has_nan` + +### `rng_state.pt` contents - Python `random` state - NumPy RNG state - PyTorch CPU RNG state - CUDA RNG state -### `data_progress` contents +## Restore Paths -First version stores progress in a compact form: +## Full-Parameter Training -- `consumed_train_samples` -- optionally `consumed_batches` when this is easier to compute reliably in a given entrypoint +For full-parameter training, model weights are restored during initialization. -The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. +### Full-parameter restore flow -## Model Save and Load Semantics +1. Construct `TransformersModel(model_id=ckpt_dir, ...)` +2. `__init__` uses `from_pretrained(ckpt_dir, ...)` to restore weights +3. Create optimizer, scheduler, and scaler objects +4. Call `load_training_state(ckpt_dir)` to restore training state +5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` + +This means full-parameter resume does not need a separate model-weight loading method after initialization. It only needs explicit training-state restoration. + +## LoRA Training + +For LoRA training, the existing adapter-weight load path remains in place. + +### LoRA restore flow + +1. Construct the model and adapter objects as today +2. Restore adapter weights through the existing `load()` path +3. Create optimizer, scheduler, and scaler objects +4. Call the same `load_training_state(ckpt_dir)` method to restore training state +5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` + +## Unified training-state method + +The model layer gains a shared helper such as `load_training_state(ckpt_dir)`. + +This method restores: -## `TransformersModel.save` +- `optimizer.pt` +- `scheduler.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -`TransformersModel.save(..., save_optimizer=True)` is extended to: +It assumes the corresponding optimizer, scheduler, and scaler objects have already been created before invocation. + +## Model Save and Load Semantics -1. Save weights exactly as today -2. Save tokenizer exactly as today -3. Save `optimizer.pt` and `scheduler.pt` exactly as today -4. Save `training_state.pt` +## Save behavior -When `save_optimizer=False`, save remains weight-only and does not produce strict resume metadata. +When saving with optimizer state enabled, the checkpoint includes: -## `TransformersModel.load` +- weights in the existing full-model or LoRA format +- tokenizer artifacts +- `optimizer.pt` +- `scheduler.pt` +- `scaler.pt` +- `trainer_state.json` +- `rng_state.pt` -`TransformersModel.load(..., load_optimizer=False)` keeps current behavior: +When optimizer save is disabled, save remains weight-only and does not produce strict resume metadata. -- load model weights only +## Strict training-state restore -`TransformersModel.load(..., load_optimizer=True)` becomes strict model-state resume: +Strict model-state resume restores: -1. Resolve checkpoint directory -2. Load model weights -3. Load optimizer and scheduler state -4. Load `training_state.pt` -5. Restore scaler state -6. Restore RNG state -7. Restore `cur_step` and `gradient_accumulation_steps` +- optimizer state +- scheduler state +- scaler state +- RNG state +- `cur_step` +- `gradient_accumulation_steps` +- data-progress metadata ### Failure behavior -When `load_optimizer=True`, missing required model training state is an error: +When strict training-state restore is requested, missing required model training state is an error: -- missing `training_state.pt` -> fail +- missing `trainer_state.json` -> fail - missing `optimizer.pt` when optimizer restore is required -> fail - missing `scheduler.pt` when scheduler restore is required -> fail -- malformed required fields in `training_state.pt` -> fail +- missing `scaler.pt` when scaler restore is required -> fail +- missing `rng_state.pt` when RNG restore is required -> fail +- malformed required fields -> fail This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. -This matches the `MegatronModel` contract more closely than the current `TransformersModel` behavior. - ## Training Progress and Data Skipping Twinkle does not currently have a central trainer abstraction. Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. @@ -226,10 +267,12 @@ The practical integration model is: 1. Parse or receive the three resume parameters 2. If `resume_from_checkpoint` is unset, construct dataloader normally -3. If `resume_only_model=False`, call existing model load with strict restore semantics -4. If `resume_only_model=True`, call weight-only model load -5. If data skipping is enabled, read progress metadata from `training_state.pt` -6. Recreate the dataloader with skip arguments applied +3. Construct model weights through the appropriate path + - full-parameter: restore through `__init__` + - LoRA: restore through existing adapter load logic +4. If `resume_only_model=False`, call `load_training_state(ckpt_dir)` +5. If `resume_only_model=True` and `ignore_data_skip=False`, read `trainer_state.json` for progress only +6. Recreate the dataloader with skip arguments applied when skipping is enabled This keeps the training loop explicit and compatible with current Twinkle examples. @@ -242,17 +285,17 @@ Server-side checkpoint save/load behavior should preserve current APIs while add When server-side save endpoints request optimizer save: - save the model checkpoint as today -- save `optimizer.pt`, `scheduler.pt`, and `training_state.pt` +- save `optimizer.pt`, `scheduler.pt`, `scaler.pt`, `trainer_state.json`, and `rng_state.pt` - persist checkpoint metadata through the existing checkpoint manager ### Load path -Current `load_optimizer=True` behavior is retained as the trigger for strict model-state restore. +Current model load APIs remain the weight-loading trigger. The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: -- call strict resume -- call weight-only resume +- restore full training state +- restore weight only - request data skipping The underlying server model APIs do not need a new trainer object to support this. @@ -265,12 +308,13 @@ Existing checkpoints remain loadable in weight-only mode. Examples: -- `model.load(path, load_optimizer=False)` continues to work +- weight-only initialization for full-parameter checkpoints continues to work +- existing LoRA weight loading continues to work - inference-only consumers remain unaffected ### Old checkpoints under strict resume -Old checkpoints that lack `training_state.pt` are not valid for strict `TransformersModel` resume. +Old checkpoints that lack the new training-state files are not valid for strict resume. Expected behavior: @@ -310,16 +354,25 @@ Skip logic must be compatible with current device-mesh sharding. The implementat Tests should cover: -### Model-state save/load +### Full-parameter training resume + +- initializing with `model_id=ckpt_dir` restores weights +- `load_training_state(ckpt_dir)` restores optimizer, scheduler, scaler, RNG, and step metadata + +### LoRA training resume + +- adapter-weight restore continues to work +- `load_training_state(ckpt_dir)` restores shared training state correctly + +### Strict restore failures -- `training_state.pt` is written when optimizer save is enabled -- scaler, RNG, `cur_step`, and accumulation settings are restored - strict resume fails when required files are missing +- malformed state files fail clearly ### Weight-only compatibility - legacy checkpoints still load in weight-only mode -- `resume_only_model=True` restores weights without optimizer and RNG +- `resume_only_model=True` restores weights without optimizer, scheduler, scaler, or RNG ### Data progress skipping @@ -327,20 +380,16 @@ Tests should cover: - skip behavior remains correct with device-mesh sharding - iterable and streaming datasets emit warnings and continue without skipping -### Failure cases - -- missing progress metadata when data skipping is requested -- malformed `training_state.pt` -- mismatch between requested strict resume and available checkpoint contents - ## Implementation Outline -1. Extend `TransformersModel.save/load` to persist and restore `training_state.pt` -2. Add helper methods for RNG save/load and training-state serialization -3. Extend dataloader and sampler stack to support skip arguments for map-style datasets -4. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints -5. Add warnings for unsupported iterable/streaming data skipping -6. Update docs and examples to prefer trainer-level resume parameters over ad hoc `model.load(..., load_optimizer=True)` logic +1. Add model helpers for saving and loading split training-state files +2. Implement `load_training_state(ckpt_dir)` with shared behavior for full-parameter and LoRA training +3. Keep full-parameter weight restore in `__init__` +4. Keep LoRA weight restore in the existing adapter load path +5. Extend dataloader and sampler stack to support skip arguments for map-style datasets +6. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints +7. Add warnings for unsupported iterable and streaming data skipping +8. Update docs and examples to show the new resume contract ## User Guidance @@ -350,4 +399,5 @@ Recommended guidance text: - `resume_only_model=False` performs full resume - `resume_only_model=True` restores only model weights - `ignore_data_skip=True` disables progress restore and starts from step 0 +- Full-parameter checkpoints restore weights during model initialization and restore training state afterward - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data From 6eebda8d049abc7fc045b1d445502c192a81424b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:41:55 +0800 Subject: [PATCH 03/60] docs: trim resume state fields --- .../specs/2026-03-27-transformers-checkpoint-resume-design.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index 9a1d45c2..3b38a910 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -114,7 +114,6 @@ Existing weight layouts remain valid. New training-state files are added alongsi - `cur_step` - `gradient_accumulation_steps` - `consumed_train_samples` -- optionally `consumed_batches` The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. From cdd9c1bc44fed6087ab278b04c27d7e06faa7237 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 15:44:30 +0800 Subject: [PATCH 04/60] docs: add npu resume compatibility requirements --- ...7-transformers-checkpoint-resume-design.md | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index 3b38a910..7b90baba 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -20,6 +20,7 @@ Because Twinkle keeps the training loop explicit in user code, the design extend - Restore optimizer state, scheduler state, scaler state, RNG state, and step counters - Support dataset progress skipping for map-style datasets - Expose Swift-like resume controls without adding a new trainer class +- Keep training-state save and load compatible with NPU (Ascend) environments - Preserve existing weight-only loading and saving behavior ## Non-Goals @@ -129,6 +130,35 @@ The design prefers storing `consumed_train_samples` as the canonical progress va - PyTorch CPU RNG state - CUDA RNG state +## Accelerator Compatibility + +Training-state save and load must be accelerator-compatible, including Ascend NPU environments already supported by Twinkle. + +### Device-agnostic serialization + +Training-state files must use device-agnostic serialization: + +- optimizer, scheduler, scaler, and RNG payloads should be serialized in CPU-safe form +- JSON metadata stays in plain text files +- loading should first read state from CPU-safe files and then apply it to objects created on the current runtime device + +This avoids tying resume files to a specific device object layout during save. + +### RNG compatibility requirements + +RNG save and restore must branch by current accelerator backend: + +- CUDA runtime uses `torch.cuda` RNG APIs +- NPU runtime uses `torch.npu` RNG APIs +- CPU RNG and Python/NumPy RNG are always restored + +The implementation must not assume CUDA-only RNG helpers when saving or restoring training state. + +### Scope of compatibility + +The design requires resume support to work correctly in NPU environments. + +The design does not require cross-accelerator resume guarantees such as saving on GPU and resuming on NPU, or saving on NPU and resuming on GPU. The compatibility target is correct save and restore within the active supported accelerator backend. ## Restore Paths ## Full-Parameter Training @@ -400,3 +430,4 @@ Recommended guidance text: - `ignore_data_skip=True` disables progress restore and starts from step 0 - Full-parameter checkpoints restore weights during model initialization and restore training state afterward - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data + From 1542492f82db3e2a77abe9e38886e33bd2d2b005 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 27 Mar 2026 16:17:12 +0800 Subject: [PATCH 05/60] chore: ignore local worktrees --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 58f495d4..afdfcae9 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ images /custom/ megatron_output/ .qoder +.worktrees/ # Pytorch *.pth From 98831180162fab51893e04ff877f27a52ad843d2 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 08:57:49 +0800 Subject: [PATCH 06/60] wip --- client_tools/client_generator.py | 23 +- ...00\344\275\263\345\256\236\350\267\265.md" | 24 +- src/twinkle/dataloader/dataloader.py | 53 ++- src/twinkle/dataloader/device_mesh_sampler.py | 15 +- src/twinkle/dataloader/retry_sampler.py | 20 +- .../transformers/multi_lora_transformers.py | 2 +- .../model/transformers/transformers.py | 124 +++++++- src/twinkle/server/model/twinkle_handlers.py | 51 +++ .../model/multi_lora_transformers.py | 18 ++ src/twinkle_client/types/__init__.py | 3 + src/twinkle_client/types/model.py | 21 ++ tests/dataloader/test_dataloader.py | 69 +++- tests/dataloader/test_sampler.py | 46 +++ .../transformers/test_checkpoint_resume.py | 301 ++++++++++++++++++ .../model/test_twinkle_resume_routes.py | 180 +++++++++++ 15 files changed, 924 insertions(+), 26 deletions(-) create mode 100644 tests/model/transformers/test_checkpoint_resume.py create mode 100644 tests/server/model/test_twinkle_resume_routes.py diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 8927315d..f724c7c1 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. +# Copyright (c) ModelScope Contributors. All rights reserved. import ast from pathlib import Path from typing import Dict, List, Set, Tuple @@ -448,6 +448,7 @@ def generate_models(): GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -617,6 +618,23 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> None: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( @@ -850,4 +868,5 @@ def apply_patch(self, patch_cls: str, **kwargs) -> None: generate_samplers() print('\n' + '=' * 60) - print('\n✓ All client code generation complete!\n') + print('\nAll client code generation complete!\n') + diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index ad78e28d..f0112042 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -410,9 +410,19 @@ def train(): model.set_lr_scheduler('LinearLR') # 恢复训练(如有检查点) - if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + resume_from_checkpoint = resume_path + resume_only_model = False + ignore_data_skip = False + if resume_from_checkpoint: + logger.info(f'Resuming training from {resume_from_checkpoint}') + model.load(name=resume_from_checkpoint) + + if not resume_only_model: + trainer_state = model.load_training_state(resume_from_checkpoint) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + elif not ignore_data_skip: + progress = model.read_training_progress(resume_from_checkpoint) + dataloader.skip_consumed_samples(progress['consumed_train_samples']) logger.info(model.get_train_configs()) @@ -445,6 +455,14 @@ if __name__ == '__main__': - 支持断点续训、检查点管理 - 可动态切换 LoRA 适配器、损失函数、优化器等组件 +Resume 模式: + +- `resume_from_checkpoint=None`:开始新的训练任务。 +- `resume_only_model=False`:恢复权重、optimizer、scheduler、scaler、RNG 和进度元数据。 +- `resume_only_model=True` 且 `ignore_data_skip=False`:恢复权重,读取进度元数据,并跳过已消费样本。 +- `resume_only_model=True` 且 `ignore_data_skip=True`:只恢复权重,训练步数和数据进度从 0 开始。 +- `skip_consumed_samples(...)` 不适用于 iterable / streaming dataset。 + ### 3.2 Tinker Client:简洁即用 Tinker 是一个轻量级训练 API。Twinkle 对 Tinker 客户端提供完整支持,几行代码就能拉起训练。已有 Tinker 代码的项目可以直接迎移到 Twinkle 服务端。 diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..e2ef57ce 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import warnings from functools import partial from typing import Callable, Optional, Type, Union @@ -51,6 +53,9 @@ def __init__(self, self.dataloader_params['batch_size'] = batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None + self._skip_samples = 0 + self._base_batch_sampler = None + self._base_sampler = None self._set_work_init_fn() def _set_work_init_fn(self): @@ -97,7 +102,9 @@ def _lazy_init_dataloader(self): if not isinstance(self.dataset, IterableDataset): self.dataloader.__initialized = False - self._repeat_sample_and_shard() + self._base_batch_sampler = self.dataloader.batch_sampler + self._base_sampler = self.dataloader.sampler + self._rebuild_sampler_stack() self.dataloader.__initialized = True @remote_function() @@ -116,11 +123,39 @@ def __iter__(self): max_retries=self.max_retries) return _iter - def _repeat_sample_and_shard(self): - if self.dataloader.batch_sampler is not None and hasattr(self.dataloader.batch_sampler, 'sampler'): - self.dataloader.batch_sampler.sampler = RetrySampler( - self.dataloader.batch_sampler.sampler, self.dataset, max_retries=self.max_retries) - self.dataloader.batch_sampler = DeviceMeshSampler(self.dataloader.batch_sampler, self.device_mesh, - self.min_batch_size) - elif self.dataloader.sampler is not None: - self.dataloader.sampler = RetrySampler(self.dataloader.sampler, self.dataset, max_retries=self.max_retries) + @remote_function() + def skip_consumed_samples(self, consumed_train_samples: int) -> None: + from torch.utils.data import IterableDataset + + if isinstance(self.dataset, IterableDataset): + warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') + self._skip_samples = 0 + return + + self._skip_samples = max(int(consumed_train_samples), 0) + if self.dataloader is not None: + self.dataloader.__initialized = False + self._rebuild_sampler_stack() + self.dataloader.__initialized = True + + def _rebuild_sampler_stack(self): + if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): + batch_sampler = copy.copy(self._base_batch_sampler) + batch_sampler.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + ) + self.dataloader.batch_sampler = DeviceMeshSampler( + batch_sampler, + self.device_mesh, + self.min_batch_size, + skip_samples=self._skip_samples, + ) + elif self._base_sampler is not None: + self.dataloader.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + skip_samples=self._skip_samples, + ) diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py index 955b85cd..1f649de3 100644 --- a/src/twinkle/dataloader/device_mesh_sampler.py +++ b/src/twinkle/dataloader/device_mesh_sampler.py @@ -12,15 +12,28 @@ class DeviceMeshSampler(BatchSampler): device_mesh: The device mesh. """ - def __init__(self, original_sampler: BatchSampler, device_mesh: DeviceMesh, min_batch_size: int = None): + def __init__(self, + original_sampler: BatchSampler, + device_mesh: DeviceMesh, + min_batch_size: int = None, + skip_samples: int = 0): self.original_sampler = original_sampler self.device_mesh = device_mesh self.min_batch_size = min_batch_size + self.skip_samples = skip_samples if self.min_batch_size is None and self.device_mesh is not None: self.min_batch_size = self.device_mesh.data_world_size def __iter__(self): + skipped = 0 for batch in self.original_sampler: + if skipped < self.skip_samples: + if skipped + len(batch) <= self.skip_samples: + skipped += len(batch) + continue + batch = batch[self.skip_samples - skipped:] + skipped = self.skip_samples + if not self.device_mesh: yield batch else: diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 62f05660..43307b1a 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -14,13 +14,16 @@ class RetrySampler(Sampler): max_retries: The maximum number of retries. """ - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20): + def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20, skip_samples: int = 0): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries + self.skip_samples = skip_samples def __iter__(self): - total = 0 + emitted = 0 + seen_valid = 0 + target_total = max(len(self.dataset) - self.skip_samples, 0) for idx in self.original_sampler: for _ in range(self.max_retries): try: @@ -29,8 +32,11 @@ def __iter__(self): data = self.dataset[idx] if not data: continue + seen_valid += 1 + if seen_valid <= self.skip_samples: + break yield idx - total += 1 + emitted += 1 break except Exception: # noqa import traceback @@ -39,12 +45,11 @@ def __iter__(self): else: raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') - origin_dataset_len = len(self.dataset) - if total >= origin_dataset_len: + if emitted >= target_total: return for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): - if total >= origin_dataset_len: + if emitted >= target_total: raise StopIteration for _ in range(self.max_retries): try: @@ -53,7 +58,8 @@ def __iter__(self): if not data: continue yield idx - total += 1 + emitted += 1 + break except Exception: # noqa import traceback traceback.print_exc() diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 0900f52b..7db01146 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -237,7 +237,7 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self.load_training_state(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 520aaf9f..d588b8e7 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -3,8 +3,10 @@ import contextlib import json import os +import random import re import threading +import numpy as np import torch import torch.distributed as dist import transformers @@ -866,6 +868,11 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if kwargs.get('save_optimizer', False): self._save_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._save_training_state( + checkpoint_dir, + adapter_name=adapter_name, + consumed_train_samples=kwargs.get('consumed_train_samples', 0), + ) return checkpoint_dir @@ -881,6 +888,33 @@ def _save_optimizer(self, output_dir, **kwargs): if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + def _save_training_state(self, output_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if not Platform.is_master(): + return + + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + } + with open(os.path.join(output_dir, 'trainer_state.json'), 'w', encoding='utf-8') as f: + json.dump(trainer_state, f) + + if optimizer_config.scaler is not None: + torch.save( + { + 'scaler_state_dict': optimizer_config.scaler.state_dict(), + 'scaler_has_nan': optimizer_config.scaler_has_nan, + }, + os.path.join(output_dir, 'scaler.pt'), + ) + + torch.save(self._get_training_rng_state(), os.path.join(output_dir, 'rng_state.pt')) + def _save_tokenizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] @@ -946,20 +980,106 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): def _load_optimizer(self, checkpoint_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + strict = kwargs.pop('strict', False) # assume optimizer and lr_scheduler are created optimizer_config = self.optimizer_group[adapter_name] optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt') scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt') + if strict and not os.path.exists(optimizer_path): + raise FileNotFoundError(optimizer_path) + if strict and not os.path.exists(scheduler_path): + raise FileNotFoundError(scheduler_path) + if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: - state_dict = torch.load(optimizer_path, map_location='cpu') + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: - state_dict = torch.load(scheduler_path, map_location='cpu') + state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False) optimizer_config.lr_scheduler.load_state_dict(state_dict) + def _load_scaler_state(self, scaler_path, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + if optimizer_config.scaler is None: + raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}') + + scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=False) + optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict']) + optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False) + + def _get_training_rng_state(self): + state = { + 'python_rng_state': random.getstate(), + 'numpy_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + } + if hasattr(torch, 'npu') and torch.npu.is_available(): + state['device_type'] = 'npu' + state['device_rng_state'] = torch.npu.get_rng_state() + elif torch.cuda.is_available(): + state['device_type'] = 'cuda' + state['device_rng_state'] = torch.cuda.get_rng_state_all() + else: + state['device_type'] = 'cpu' + state['device_rng_state'] = None + return state + + def _load_rng_state(self, rng_path): + rng_state = torch.load(rng_path, map_location='cpu', weights_only=False) + random.setstate(rng_state['python_rng_state']) + np.random.set_state(rng_state['numpy_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + + device_type = rng_state.get('device_type') + device_rng_state = rng_state.get('device_rng_state') + if device_type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available() and device_rng_state is not None: + torch.npu.set_rng_state(device_rng_state) + elif device_type == 'cuda' and torch.cuda.is_available() and device_rng_state is not None: + torch.cuda.set_rng_state_all(device_rng_state) + + @remote_function() + def read_training_progress(self, checkpoint_dir, **kwargs): + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if not os.path.exists(trainer_state_path): + raise FileNotFoundError(trainer_state_path) + + with open(trainer_state_path, 'r', encoding='utf-8') as f: + trainer_state = json.load(f) + + required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'} + missing_keys = required_keys - trainer_state.keys() + if missing_keys: + raise ValueError(f'Missing trainer_state keys: {sorted(missing_keys)}') + return trainer_state + + @remote_function() + def load_training_state(self, checkpoint_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + required_paths = { + 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), + 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), + 'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'), + 'scaler': os.path.join(checkpoint_dir, 'scaler.pt'), + 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), + } + for path in required_paths.values(): + if not os.path.exists(path): + raise FileNotFoundError(path) + + trainer_state = self.read_training_progress(checkpoint_dir) + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True) + self._load_scaler_state(required_paths['scaler'], adapter_name=adapter_name) + self._load_rng_state(required_paths['rng']) + + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + return trainer_state + @remote_function(collect='first') def get_state_dict(self, **kwargs): return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group())) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 535a3bd7..259c50c4 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -10,6 +10,8 @@ import torch import traceback +import os +from pathlib import Path from fastapi import Depends, FastAPI, HTTPException, Request from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable @@ -347,6 +349,55 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, task_type='load')) + @app.post('/twinkle/load_training_state') + async def load_training_state( + request: Request, + body: types.LoadTrainingStateRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + self.model.load_training_state( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + + await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) + + @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) + async def read_training_progress( + request: Request, + body: types.ReadTrainingProgressRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.TrainingProgressResponse: + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_resource_exists(adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) + ret = self.model.read_training_progress( + checkpoint_dir, + adapter_name=adapter_name, + **extra_kwargs, + ) + return {'result': ret} + + return await run_task(self.schedule_task_and_wait(_task, task_type='read_training_progress')) + @app.post('/twinkle/upload_to_hub') async def upload_to_hub( request: Request, diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 743125d9..2a8afdce 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -19,6 +19,7 @@ GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -188,6 +189,23 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def load_training_state(self, name: str, **kwargs) -> None: + """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + response = http_post( + url=f'{self.server_url}/load_training_state', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: + """Read progress-only checkpoint metadata for resume-only-model flows.""" + response = http_post( + url=f'{self.server_url}/read_training_progress', + json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 00b1f967..0e5d37e1 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -23,11 +23,13 @@ GetStateDictRequest, GetStateDictResponse, GetTrainConfigsResponse, + LoadTrainingStateRequest, LoadRequest, LoadResponse, LrStepResponse, ModelResult, OkResponse, + ReadTrainingProgressRequest, SaveRequest, SaveResponse, SetLossRequest, @@ -41,6 +43,7 @@ SetTemplateRequest, SetTemplateResponse, StepResponse, + TrainingProgressResponse, UploadToHubRequest, UploadToHubResponse, ZeroGradResponse, diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index e594bae4..0b6d7c08 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -89,6 +89,22 @@ class Config: extra = 'allow' +class LoadTrainingStateRequest(BaseModel): + adapter_name: str + name: str + + class Config: + extra = 'allow' + + +class ReadTrainingProgressRequest(BaseModel): + adapter_name: str + name: str + + class Config: + extra = 'allow' + + class AddAdapterRequest(BaseModel): adapter_name: str config: str @@ -212,6 +228,11 @@ class SaveResponse(BaseModel): checkpoint_dir: Optional[str] = None +class TrainingProgressResponse(BaseModel): + """Response for /read_training_progress endpoint (returns progress metadata).""" + result: Dict[str, Any] + + # --- Void responses (return None → OkResponse) --- class BackwardResponse(OkResponse): diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 79bf78ad..232c6fe6 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,14 +1,29 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import concurrent.futures import numpy as np import os import pytest from pathlib import Path +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import IterableDataset as TorchIterableDataset + + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor import twinkle from twinkle import DeviceMesh from twinkle.data_format import Message from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta +from twinkle.dataset import Dataset, DatasetMeta, IterableDataset from twinkle.processor import InputProcessor twinkle.initialize(mode='local') @@ -22,6 +37,36 @@ def convert_to_messages(example): return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]} +def _build_resume_rows(): + return [ + {'text': 'Hello world'}, + {'text': 'Test data'}, + {'text': 'Another example'}, + {'text': 'Sample text'}, + ] + + +class _InMemoryDataset(TorchDataset): + + def __init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + +class _InMemoryIterableDataset(TorchIterableDataset): + + def __init__(self, rows): + self.rows = rows + + def __iter__(self): + return iter(self.rows) + + class TestDataLoaderBasic: def test_dataloader_basic(self): @@ -157,3 +202,25 @@ def test_retry_sampler_length(self): total_samples = sum(len(batch) for batch in dataloader) assert total_samples == original_len + + +class TestResumeSkip: + + def test_dataloader_skip_consumed_samples_for_map_style_dataset(self): + dataset = _InMemoryDataset(_build_resume_rows()) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + + dataloader.skip_consumed_samples(2) + batches = list(dataloader) + + texts = [item['text'] for batch in batches for item in batch] + assert texts[0] == 'Another example' + + def test_dataloader_warns_when_skip_requested_for_iterable_dataset(self, recwarn): + dataset = _InMemoryIterableDataset(_build_resume_rows()) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + + dataloader.skip_consumed_samples(2) + next(iter(dataloader)) + + assert 'does not support consumed-data skipping' in str(recwarn.pop(UserWarning).message) diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index b8438207..1dd7d7ca 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -1,10 +1,26 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import concurrent.futures import os +import numpy as np import pytest from pathlib import Path from torch.utils.data import RandomSampler, SequentialSampler +from torch.utils.data import Dataset as TorchDataset + + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor import twinkle +from twinkle import DeviceMesh from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta @@ -162,3 +178,33 @@ def test_sequential_vs_random_order(self): different = seq_texts != rand_texts assert different or len(seq_texts) == 1 + + +class TestResumeSkipSamplerOrdering: + + def test_sequential_sampler_skip_happens_before_device_mesh_slice(self): + class _InMemoryDataset(TorchDataset): + + def __init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, idx): + return self.rows[idx] + + dataset = _InMemoryDataset([ + {'text': 'Hello world'}, + {'text': 'Test data'}, + {'text': 'Another example'}, + {'text': 'Sample text'}, + ]) + sampler = SequentialSampler(dataset) + device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', )) + dataloader = DataLoader(dataset=dataset, batch_size=4, sampler=sampler, device_mesh=device_mesh, num_workers=0) + + dataloader.skip_consumed_samples(2) + first_batch = list(dataloader)[0] + + assert first_batch[0]['text'] == 'Another example' diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py new file mode 100644 index 00000000..d616156a --- /dev/null +++ b/tests/model/transformers/test_checkpoint_resume.py @@ -0,0 +1,301 @@ +import concurrent.futures +import sys +import types +import uuid +from pathlib import Path +from unittest.mock import Mock + +import pytest +from peft import LoraConfig +from tokenizers import Tokenizer +from tokenizers.models import WordLevel +from tokenizers.pre_tokenizers import Whitespace +from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast + + +class _NoOpProcessPoolExecutor: + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor + +if 'zmq' not in sys.modules: + zmq_stub = types.ModuleType('zmq') + + class _ZmqError: + class Again(Exception): + pass + + class _ZmqSocket: + def setsockopt(self, *args, **kwargs): + pass + + def setsockopt_string(self, *args, **kwargs): + pass + + def bind(self, *args, **kwargs): + pass + + def connect(self, *args, **kwargs): + pass + + def send_string(self, *args, **kwargs): + pass + + def send_pyobj(self, *args, **kwargs): + pass + + def recv_string(self, *args, **kwargs): + return '' + + def recv_pyobj(self, *args, **kwargs): + return None + + class _ZmqContext: + def socket(self, *args, **kwargs): + return _ZmqSocket() + + zmq_stub.Context = _ZmqContext + zmq_stub.Socket = _ZmqSocket + zmq_stub.REQ = 0 + zmq_stub.REP = 1 + zmq_stub.PUB = 2 + zmq_stub.SUB = 3 + zmq_stub.SNDMORE = 4 + zmq_stub.IPV6 = 5 + zmq_stub.SUBSCRIBE = 6 + zmq_stub.RCVTIMEO = 7 + zmq_stub.SNDTIMEO = 8 + zmq_stub.LINGER = 9 + zmq_stub.error = _ZmqError + sys.modules['zmq'] = zmq_stub + +from twinkle import DeviceMesh +from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel + + +def build_tiny_tokenizer(): + vocab = {'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, 'hello': 4, 'world': 5} + backend = Tokenizer(WordLevel(vocab=vocab, unk_token='[UNK]')) + backend.pre_tokenizer = Whitespace() + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=backend, + pad_token='[PAD]', + bos_token='[BOS]', + eos_token='[EOS]', + unk_token='[UNK]', + ) + return tokenizer + + +@pytest.fixture +def tmp_path(): + root = Path(r'C:\Users\weika\.codex\memories') / 'twinkle-tests' + root.mkdir(exist_ok=True) + path = root / uuid.uuid4().hex + path.mkdir() + return path + + +@pytest.fixture +def tiny_local_model_dir(tmp_path): + model_dir = tmp_path / 'tiny-gpt2' + model_dir.mkdir() + + tokenizer = build_tiny_tokenizer() + tokenizer.save_pretrained(model_dir) + + config = GPT2Config( + vocab_size=tokenizer.vocab_size, + n_layer=1, + n_head=1, + n_embd=16, + n_positions=32, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + GPT2LMHeadModel(config).save_pretrained(model_dir) + return model_dir + + +def build_full_param_model(model_dir): + return TransformersModel( + model_cls='GPT2LMHeadModel', + model_id=str(model_dir), + mixed_precision='no', + grad_scaler_config={}, + ) + + +def build_multi_lora_model(model_dir): + device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1) + model = MultiLoraTransformersModel( + model_cls='GPT2LMHeadModel', + model_id=str(model_dir), + mixed_precision='no', + device_mesh=device_mesh, + grad_scaler_config={}, + ) + model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['c_attn'])) + return model + + +def prepare_full_param_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): + model = build_full_param_model(model_dir) + model.set_optimizer('AdamW', lr=1e-4) + model.set_lr_scheduler('LinearLR') + model.set_grad_scaler(device='cpu') + model.optimizer_group[''].cur_step = cur_step + return Path( + model.save( + name='full-resume', + output_dir=str(tmp_path), + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + )) + + +def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): + model = build_multi_lora_model(model_dir) + model.set_optimizer('AdamW', adapter_name='default', lr=1e-4) + model.set_lr_scheduler('LinearLR', adapter_name='default') + model.set_grad_scaler(adapter_name='default', device='cpu') + model.optimizer_group['default'].cur_step = cur_step + return Path( + model.save( + name='lora-resume', + output_dir=str(tmp_path), + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + )) + + +def run_resume_only_model_flow(model, dataloader, checkpoint_dir, ignore_data_skip): + if ignore_data_skip: + return None + progress = model.read_training_progress(str(checkpoint_dir)) + dataloader.skip_consumed_samples(progress['consumed_train_samples']) + model.optimizer_group[''].cur_step = progress['cur_step'] + model.optimizer_group[''].gradient_accumulation_steps = progress['gradient_accumulation_steps'] + return progress + + +def _ensure_file_exists(path: Path): + if path.exists(): + return + if path.name == 'trainer_state.json': + path.write_text('{"cur_step": 3}', encoding='utf-8') + else: + path.write_bytes(b'placeholder') + + +def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir): + model = build_full_param_model(tiny_local_model_dir) + model.set_optimizer('AdamW', lr=1e-4) + model.set_lr_scheduler('LinearLR') + model.set_grad_scaler(device='cpu') + model.optimizer_group[''].cur_step = 7 + + ckpt_dir = Path(model.save(name='resume-step', output_dir=str(tmp_path), save_optimizer=True)) + + assert (ckpt_dir / 'optimizer.pt').exists() + assert (ckpt_dir / 'scheduler.pt').exists() + assert (ckpt_dir / 'scaler.pt').exists() + assert (ckpt_dir / 'trainer_state.json').exists() + assert (ckpt_dir / 'rng_state.pt').exists() + + +@pytest.mark.parametrize( + 'missing_name, expected_pattern', + [ + ('optimizer.pt', 'optimizer.pt'), + ('scheduler.pt', 'scheduler.pt'), + ('scaler.pt', 'scaler.pt'), + ('rng_state.pt', 'rng_state.pt'), + ('trainer_state.json', 'trainer_state.json'), + ], +) +def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name, + expected_pattern): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) + _ensure_file_exists(ckpt_dir / missing_name) + (ckpt_dir / missing_name).unlink() + + restored = build_full_param_model(ckpt_dir) + restored.set_optimizer('AdamW', lr=1e-4) + restored.set_lr_scheduler('LinearLR') + restored.set_grad_scaler(device='cpu') + + with pytest.raises(FileNotFoundError, match=expected_pattern): + restored.load_training_state(str(ckpt_dir)) + + +def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) + (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8') + + restored = build_full_param_model(ckpt_dir) + restored.set_optimizer('AdamW', lr=1e-4) + restored.set_lr_scheduler('LinearLR') + restored.set_grad_scaler(device='cpu') + + with pytest.raises((KeyError, ValueError), match='gradient_accumulation_steps|consumed_train_samples'): + restored.load_training_state(str(ckpt_dir)) + + +def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18) + restored = build_full_param_model(ckpt_dir) + restored.set_optimizer('AdamW', lr=1e-4) + restored.set_lr_scheduler('LinearLR') + restored.set_grad_scaler(device='cpu') + + trainer_state = restored.load_training_state(str(ckpt_dir)) + + assert trainer_state['cur_step'] == 9 + assert restored.optimizer_group[''].cur_step == 9 + + +def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_lora_checkpoint(tmp_path, tiny_local_model_dir, cur_step=5, consumed_train_samples=10) + restored = build_multi_lora_model(tiny_local_model_dir) + restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default') + restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4) + restored.set_lr_scheduler('LinearLR', adapter_name='default') + restored.set_grad_scaler(adapter_name='default', device='cpu') + + trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default') + + assert trainer_state['cur_step'] == 5 + + +def test_read_training_progress_supports_resume_only_model(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) + restored = build_full_param_model(ckpt_dir) + dataloader = Mock() + + progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=False) + + assert progress['cur_step'] == 6 + assert progress['consumed_train_samples'] == 12 + assert restored.optimizer_group[''].cur_step == 6 + dataloader.skip_consumed_samples.assert_called_once_with(12) + + +def test_resume_only_model_ignore_data_skip_leaves_progress_unrestored(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) + restored = build_full_param_model(ckpt_dir) + dataloader = Mock() + restored.read_training_progress = Mock(wraps=restored.read_training_progress) + + progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=True) + + assert progress is None + assert restored.optimizer_group[''].cur_step == 0 + restored.read_training_progress.assert_not_called() + dataloader.skip_consumed_samples.assert_not_called() diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py new file mode 100644 index 00000000..7932ca20 --- /dev/null +++ b/tests/server/model/test_twinkle_resume_routes.py @@ -0,0 +1,180 @@ +import concurrent.futures +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import Mock + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +class _NoOpProcessPoolExecutor: + + def __init__(self, *args, **kwargs): + pass + + def submit(self, fn, *args, **kwargs): + raise RuntimeError('Process pool is disabled in this test environment.') + + +concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor + +if 'tinker' not in sys.modules: + tinker_module = types.ModuleType('tinker') + tinker_types_module = types.ModuleType('tinker.types') + + class _TinkerPlaceholder: + pass + + for name in ( + 'CreateModelRequest', + 'TrainingRun', + 'TrainingRunsResponse', + 'Cursor', + 'Checkpoint', + 'CheckpointsListResponse', + 'ParsedCheckpointTinkerPath', + 'WeightsInfoResponse', + ): + setattr(tinker_types_module, name, _TinkerPlaceholder) + tinker_module.types = tinker_types_module + sys.modules['tinker'] = tinker_module + sys.modules['tinker.types'] = tinker_types_module + +if 'twinkle.server.common' not in sys.modules: + common_module = types.ModuleType('twinkle.server.common') + checkpoint_factory_module = types.ModuleType('twinkle.server.common.checkpoint_factory') + checkpoint_factory_module.create_checkpoint_manager = lambda token, client_type='twinkle': None + checkpoint_factory_module.create_training_run_manager = lambda token, client_type='twinkle': None + common_module.checkpoint_factory = checkpoint_factory_module + sys.modules['twinkle.server.common'] = common_module + sys.modules['twinkle.server.common.checkpoint_factory'] = checkpoint_factory_module + +from twinkle_client.types.checkpoint import ResolvedLoadPath + +_HANDLERS_PATH = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py' +_HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH) +handlers = importlib.util.module_from_spec(_HANDLERS_SPEC) +sys.modules[_HANDLERS_SPEC.name] = handlers +_HANDLERS_SPEC.loader.exec_module(handlers) + + +class _FakeCheckpointManager: + + def resolve_load_path(self, path: str) -> ResolvedLoadPath: + return ResolvedLoadPath( + checkpoint_name='ckpt-1', + checkpoint_dir='D:/resolved/weights', + is_twinkle_path=True, + training_run_id='run-1', + checkpoint_id='weights/ckpt-1', + ) + + +class _FakeModelManagement: + + def __init__(self): + self.model = Mock() + + async def _on_request_start(self, request): + request.state.request_id = 'req-1' + return 'token-1' + + def assert_resource_exists(self, adapter_name): + return None + + async def schedule_task_and_wait(self, task, task_type=''): + return await task() + + +def _build_test_client(monkeypatch): + management = _FakeModelManagement() + checkpoint_manager = _FakeCheckpointManager() + monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager) + + app = FastAPI() + handlers._register_twinkle_routes(app, lambda: management) + return TestClient(app), management + + +def _run_remote_resume_case(client: TestClient, *, resume_from_checkpoint, resume_only_model, ignore_data_skip): + if resume_from_checkpoint is None: + return None + if not resume_only_model: + return client.post('/twinkle/load_training_state', json={'name': resume_from_checkpoint, 'adapter_name': ''}) + if not ignore_data_skip: + return client.post('/twinkle/read_training_progress', json={'name': resume_from_checkpoint, 'adapter_name': ''}) + return None + + +def test_case_1_no_resume_call_leaves_remote_resume_helpers_unused(monkeypatch): + """Case 1: resume_from_checkpoint is None, so no remote resume helper should be called.""" + client, management = _build_test_client(monkeypatch) + + response = _run_remote_resume_case( + client, + resume_from_checkpoint=None, + resume_only_model=False, + ignore_data_skip=False, + ) + + assert response is None + management.model.load_training_state.assert_not_called() + management.model.read_training_progress.assert_not_called() + + +def test_case_2_resume_only_model_false_uses_load_training_state_route(monkeypatch): + """Case 2: resume_only_model=False should use load_training_state().""" + client, management = _build_test_client(monkeypatch) + + response = _run_remote_resume_case( + client, + resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + resume_only_model=False, + ignore_data_skip=False, + ) + + assert response.status_code == 200 + management.model.load_training_state.assert_called_once_with( + 'D:/resolved/weights/ckpt-1', + adapter_name=None, + ) + management.model.read_training_progress.assert_not_called() + + +def test_case_3_resume_only_model_true_without_ignore_data_skip_reads_progress_only(monkeypatch): + """Case 3: resume_only_model=True and ignore_data_skip=False should use read_training_progress() only.""" + client, management = _build_test_client(monkeypatch) + management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12} + + response = _run_remote_resume_case( + client, + resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + resume_only_model=True, + ignore_data_skip=False, + ) + + assert response.status_code == 200 + assert response.json()['result']['consumed_train_samples'] == 12 + management.model.read_training_progress.assert_called_once_with( + 'D:/resolved/weights/ckpt-1', + adapter_name=None, + ) + management.model.load_training_state.assert_not_called() + + +def test_case_4_resume_only_model_true_with_ignore_data_skip_uses_neither_helper(monkeypatch): + """Case 4: resume_only_model=True and ignore_data_skip=True should call neither remote helper.""" + client, management = _build_test_client(monkeypatch) + + response = _run_remote_resume_case( + client, + resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + resume_only_model=True, + ignore_data_skip=True, + ) + + assert response is None + management.model.load_training_state.assert_not_called() + management.model.read_training_progress.assert_not_called() From d41a634f3e02d8cb60837c79b24894f5283ac316 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 11:51:04 +0800 Subject: [PATCH 07/60] wip --- .gitignore | 1 + .../transformers/test_checkpoint_resume.py | 166 ++++++++---------- .../model/test_twinkle_resume_routes.py | 135 ++++++-------- 3 files changed, 123 insertions(+), 179 deletions(-) diff --git a/.gitignore b/.gitignore index afdfcae9..ae4d2cb3 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ wheels/ /temp MANIFEST .locks/ +tmp_test_checkpoints/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py index d616156a..f6b5b6a1 100644 --- a/tests/model/transformers/test_checkpoint_resume.py +++ b/tests/model/transformers/test_checkpoint_resume.py @@ -3,14 +3,14 @@ import types import uuid from pathlib import Path -from unittest.mock import Mock +from types import ModuleType import pytest from peft import LoraConfig from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace -from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast +from transformers import Qwen3Config, Qwen3ForCausalLM, PreTrainedTokenizerFast class _NoOpProcessPoolExecutor: @@ -23,58 +23,21 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -if 'zmq' not in sys.modules: - zmq_stub = types.ModuleType('zmq') - - class _ZmqError: - class Again(Exception): - pass - - class _ZmqSocket: - def setsockopt(self, *args, **kwargs): - pass - - def setsockopt_string(self, *args, **kwargs): - pass - - def bind(self, *args, **kwargs): - pass - - def connect(self, *args, **kwargs): - pass - - def send_string(self, *args, **kwargs): - pass - - def send_pyobj(self, *args, **kwargs): - pass +ROOT = Path(__file__).resolve().parents[3] +SRC = ROOT / 'src' +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) - def recv_string(self, *args, **kwargs): - return '' - - def recv_pyobj(self, *args, **kwargs): - return None - - class _ZmqContext: - def socket(self, *args, **kwargs): - return _ZmqSocket() - - zmq_stub.Context = _ZmqContext - zmq_stub.Socket = _ZmqSocket - zmq_stub.REQ = 0 - zmq_stub.REP = 1 - zmq_stub.PUB = 2 - zmq_stub.SUB = 3 - zmq_stub.SNDMORE = 4 - zmq_stub.IPV6 = 5 - zmq_stub.SUBSCRIBE = 6 - zmq_stub.RCVTIMEO = 7 - zmq_stub.SNDTIMEO = 8 - zmq_stub.LINGER = 9 - zmq_stub.error = _ZmqError - sys.modules['zmq'] = zmq_stub +if 'zmq' not in sys.modules: + fake_zmq = ModuleType('zmq') + fake_zmq.Socket = object + fake_zmq.Context = object + fake_zmq.REP = 0 + fake_zmq.REQ = 1 + fake_zmq.IPV6 = 2 + fake_zmq.error = types.SimpleNamespace(Again=RuntimeError) + sys.modules['zmq'] = fake_zmq -from twinkle import DeviceMesh from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel @@ -94,7 +57,7 @@ def build_tiny_tokenizer(): @pytest.fixture def tmp_path(): - root = Path(r'C:\Users\weika\.codex\memories') / 'twinkle-tests' + root = Path(__file__).parent / 'tmp_test_checkpoints' root.mkdir(exist_ok=True) path = root / uuid.uuid4().hex path.mkdir() @@ -103,28 +66,28 @@ def tmp_path(): @pytest.fixture def tiny_local_model_dir(tmp_path): - model_dir = tmp_path / 'tiny-gpt2' + model_dir = tmp_path / 'tiny-qwen3' model_dir.mkdir() tokenizer = build_tiny_tokenizer() tokenizer.save_pretrained(model_dir) - config = GPT2Config( + config = Qwen3Config( vocab_size=tokenizer.vocab_size, - n_layer=1, - n_head=1, - n_embd=16, - n_positions=32, + num_hidden_layers=1, + num_attention_heads=1, + hidden_size=16, + max_position_embeddings=32, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, ) - GPT2LMHeadModel(config).save_pretrained(model_dir) + Qwen3ForCausalLM(config).save_pretrained(model_dir) return model_dir def build_full_param_model(model_dir): return TransformersModel( - model_cls='GPT2LMHeadModel', + model_cls='Qwen3ForCausalLM', model_id=str(model_dir), mixed_precision='no', grad_scaler_config={}, @@ -134,13 +97,13 @@ def build_full_param_model(model_dir): def build_multi_lora_model(model_dir): device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1) model = MultiLoraTransformersModel( - model_cls='GPT2LMHeadModel', + model_cls='Qwen3ForCausalLM', model_id=str(model_dir), mixed_precision='no', device_mesh=device_mesh, grad_scaler_config={}, ) - model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['c_attn'])) + model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])) return model @@ -175,16 +138,6 @@ def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samp )) -def run_resume_only_model_flow(model, dataloader, checkpoint_dir, ignore_data_skip): - if ignore_data_skip: - return None - progress = model.read_training_progress(str(checkpoint_dir)) - dataloader.skip_consumed_samples(progress['consumed_train_samples']) - model.optimizer_group[''].cur_step = progress['cur_step'] - model.optimizer_group[''].gradient_accumulation_steps = progress['gradient_accumulation_steps'] - return progress - - def _ensure_file_exists(path: Path): if path.exists(): return @@ -250,6 +203,19 @@ def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_lo def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny_local_model_dir): ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18) + saved = build_full_param_model(tiny_local_model_dir) + saved.set_optimizer('AdamW', lr=1e-4) + saved.set_lr_scheduler('LinearLR') + saved.set_grad_scaler(device='cpu') + saved.optimizer_group[''].cur_step = 9 + saved.optimizer_group[''].gradient_accumulation_steps = 4 + ckpt_dir = Path( + saved.save( + name='full-resume', + output_dir=str(tmp_path), + save_optimizer=True, + consumed_train_samples=18, + )) restored = build_full_param_model(ckpt_dir) restored.set_optimizer('AdamW', lr=1e-4) restored.set_lr_scheduler('LinearLR') @@ -258,13 +224,36 @@ def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny trainer_state = restored.load_training_state(str(ckpt_dir)) assert trainer_state['cur_step'] == 9 + assert trainer_state['gradient_accumulation_steps'] == 4 + assert trainer_state['consumed_train_samples'] == 18 assert restored.optimizer_group[''].cur_step == 9 + assert restored.optimizer_group[''].gradient_accumulation_steps == 4 + + +def test_lora_load_does_not_restore_training_state_without_explicit_resume(tmp_path, tiny_local_model_dir): + saved = build_multi_lora_model(tiny_local_model_dir) + saved.set_optimizer('AdamW', adapter_name='default', lr=1e-4) + saved.set_lr_scheduler('LinearLR', adapter_name='default') + saved.set_grad_scaler(adapter_name='default', device='cpu') + saved.optimizer_group['default'].cur_step = 5 + saved.optimizer_group['default'].gradient_accumulation_steps = 3 + ckpt_dir = Path( + saved.save( + name='lora-resume', + output_dir=str(tmp_path), + adapter_name='default', + save_optimizer=True, + consumed_train_samples=10, + )) + restored = build_multi_lora_model(tiny_local_model_dir) + assert restored.optimizer_group['default'].cur_step == 0 + assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 -def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_lora_checkpoint(tmp_path, tiny_local_model_dir, cur_step=5, consumed_train_samples=10) - restored = build_multi_lora_model(tiny_local_model_dir) restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default') + assert restored.optimizer_group['default'].cur_step == 0 + assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 + restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4) restored.set_lr_scheduler('LinearLR', adapter_name='default') restored.set_grad_scaler(adapter_name='default', device='cpu') @@ -272,30 +261,19 @@ def test_lora_resume_keeps_adapter_load_separate_from_training_state(tmp_path, t trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default') assert trainer_state['cur_step'] == 5 + assert trainer_state['gradient_accumulation_steps'] == 3 + assert trainer_state['consumed_train_samples'] == 10 + assert restored.optimizer_group['default'].cur_step == 5 + assert restored.optimizer_group['default'].gradient_accumulation_steps == 3 -def test_read_training_progress_supports_resume_only_model(tmp_path, tiny_local_model_dir): +def test_read_training_progress_returns_metadata_without_mutating_optimizer_state(tmp_path, tiny_local_model_dir): ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) restored = build_full_param_model(ckpt_dir) - dataloader = Mock() - progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=False) + progress = restored.read_training_progress(str(ckpt_dir)) assert progress['cur_step'] == 6 assert progress['consumed_train_samples'] == 12 - assert restored.optimizer_group[''].cur_step == 6 - dataloader.skip_consumed_samples.assert_called_once_with(12) - - -def test_resume_only_model_ignore_data_skip_leaves_progress_unrestored(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) - restored = build_full_param_model(ckpt_dir) - dataloader = Mock() - restored.read_training_progress = Mock(wraps=restored.read_training_progress) - - progress = run_resume_only_model_flow(restored, dataloader, ckpt_dir, ignore_data_skip=True) - - assert progress is None assert restored.optimizer_group[''].cur_step == 0 - restored.read_training_progress.assert_not_called() - dataloader.skip_consumed_samples.assert_not_called() + assert restored.optimizer_group[''].gradient_accumulation_steps == 1 diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py index 7932ca20..f3523913 100644 --- a/tests/server/model/test_twinkle_resume_routes.py +++ b/tests/server/model/test_twinkle_resume_routes.py @@ -1,8 +1,8 @@ import concurrent.futures import importlib.util import sys -import types from pathlib import Path +from types import ModuleType from unittest.mock import Mock from fastapi import FastAPI @@ -20,40 +20,20 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -if 'tinker' not in sys.modules: - tinker_module = types.ModuleType('tinker') - tinker_types_module = types.ModuleType('tinker.types') +ROOT = Path(__file__).resolve().parents[3] +SRC = ROOT / 'src' +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) - class _TinkerPlaceholder: - pass - - for name in ( - 'CreateModelRequest', - 'TrainingRun', - 'TrainingRunsResponse', - 'Cursor', - 'Checkpoint', - 'CheckpointsListResponse', - 'ParsedCheckpointTinkerPath', - 'WeightsInfoResponse', - ): - setattr(tinker_types_module, name, _TinkerPlaceholder) - tinker_module.types = tinker_types_module - sys.modules['tinker'] = tinker_module - sys.modules['tinker.types'] = tinker_types_module - -if 'twinkle.server.common' not in sys.modules: - common_module = types.ModuleType('twinkle.server.common') - checkpoint_factory_module = types.ModuleType('twinkle.server.common.checkpoint_factory') - checkpoint_factory_module.create_checkpoint_manager = lambda token, client_type='twinkle': None - checkpoint_factory_module.create_training_run_manager = lambda token, client_type='twinkle': None - common_module.checkpoint_factory = checkpoint_factory_module - sys.modules['twinkle.server.common'] = common_module - sys.modules['twinkle.server.common.checkpoint_factory'] = checkpoint_factory_module +if 'twinkle.server.common.checkpoint_factory' not in sys.modules: + fake_checkpoint_factory = ModuleType('twinkle.server.common.checkpoint_factory') + fake_checkpoint_factory.create_checkpoint_manager = lambda *args, **kwargs: None + fake_checkpoint_factory.create_training_run_manager = lambda *args, **kwargs: None + sys.modules['twinkle.server.common.checkpoint_factory'] = fake_checkpoint_factory from twinkle_client.types.checkpoint import ResolvedLoadPath -_HANDLERS_PATH = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py' +_HANDLERS_PATH = SRC / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py' _HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH) handlers = importlib.util.module_from_spec(_HANDLERS_SPEC) sys.modules[_HANDLERS_SPEC.name] = handlers @@ -62,10 +42,13 @@ class _TinkerPlaceholder: class _FakeCheckpointManager: + def __init__(self, checkpoint_dir='./resolved/weights'): + self._checkpoint_dir = checkpoint_dir + def resolve_load_path(self, path: str) -> ResolvedLoadPath: return ResolvedLoadPath( checkpoint_name='ckpt-1', - checkpoint_dir='D:/resolved/weights', + checkpoint_dir=self._checkpoint_dir, is_twinkle_path=True, training_run_id='run-1', checkpoint_id='weights/ckpt-1', @@ -88,9 +71,9 @@ async def schedule_task_and_wait(self, task, task_type=''): return await task() -def _build_test_client(monkeypatch): +def _build_test_client(monkeypatch, checkpoint_manager=None): management = _FakeModelManagement() - checkpoint_manager = _FakeCheckpointManager() + checkpoint_manager = checkpoint_manager or _FakeCheckpointManager() monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager) app = FastAPI() @@ -98,83 +81,65 @@ def _build_test_client(monkeypatch): return TestClient(app), management -def _run_remote_resume_case(client: TestClient, *, resume_from_checkpoint, resume_only_model, ignore_data_skip): - if resume_from_checkpoint is None: - return None - if not resume_only_model: - return client.post('/twinkle/load_training_state', json={'name': resume_from_checkpoint, 'adapter_name': ''}) - if not ignore_data_skip: - return client.post('/twinkle/read_training_progress', json={'name': resume_from_checkpoint, 'adapter_name': ''}) - return None - - -def test_case_1_no_resume_call_leaves_remote_resume_helpers_unused(monkeypatch): - """Case 1: resume_from_checkpoint is None, so no remote resume helper should be called.""" +def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monkeypatch): client, management = _build_test_client(monkeypatch) - response = _run_remote_resume_case( - client, - resume_from_checkpoint=None, - resume_only_model=False, - ignore_data_skip=False, + response = client.post( + '/twinkle/load_training_state', + json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, ) - assert response is None - management.model.load_training_state.assert_not_called() + assert response.status_code == 200 + management.model.load_training_state.assert_called_once_with( + 'resolved/weights/ckpt-1', + adapter_name=None, + ) management.model.read_training_progress.assert_not_called() -def test_case_2_resume_only_model_false_uses_load_training_state_route(monkeypatch): - """Case 2: resume_only_model=False should use load_training_state().""" +def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch): client, management = _build_test_client(monkeypatch) - response = _run_remote_resume_case( - client, - resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - resume_only_model=False, - ignore_data_skip=False, + response = client.post( + '/twinkle/load_training_state', + json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': 'adapter-a'}, ) assert response.status_code == 200 management.model.load_training_state.assert_called_once_with( - 'D:/resolved/weights/ckpt-1', - adapter_name=None, + 'resolved/weights/ckpt-1', + adapter_name='req-1-adapter-a', ) - management.model.read_training_progress.assert_not_called() -def test_case_3_resume_only_model_true_without_ignore_data_skip_reads_progress_only(monkeypatch): - """Case 3: resume_only_model=True and ignore_data_skip=False should use read_training_progress() only.""" - client, management = _build_test_client(monkeypatch) - management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12} +def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(monkeypatch): + client, management = _build_test_client(monkeypatch, checkpoint_manager=_FakeCheckpointManager(checkpoint_dir=None)) - response = _run_remote_resume_case( - client, - resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - resume_only_model=True, - ignore_data_skip=False, + response = client.post( + '/twinkle/load_training_state', + json={'name': 'local-checkpoint-dir', 'adapter_name': ''}, ) assert response.status_code == 200 - assert response.json()['result']['consumed_train_samples'] == 12 - management.model.read_training_progress.assert_called_once_with( - 'D:/resolved/weights/ckpt-1', + management.model.load_training_state.assert_called_once_with( + 'local-checkpoint-dir', adapter_name=None, ) - management.model.load_training_state.assert_not_called() -def test_case_4_resume_only_model_true_with_ignore_data_skip_uses_neither_helper(monkeypatch): - """Case 4: resume_only_model=True and ignore_data_skip=True should call neither remote helper.""" +def test_read_training_progress_route_returns_progress_and_calls_model(monkeypatch): client, management = _build_test_client(monkeypatch) + management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12} - response = _run_remote_resume_case( - client, - resume_from_checkpoint='twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - resume_only_model=True, - ignore_data_skip=True, + response = client.post( + '/twinkle/read_training_progress', + json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, ) - assert response is None + assert response.status_code == 200 + assert response.json()['result']['consumed_train_samples'] == 12 + management.model.read_training_progress.assert_called_once_with( + 'resolved/weights/ckpt-1', + adapter_name=None, + ) management.model.load_training_state.assert_not_called() - management.model.read_training_progress.assert_not_called() From 21f9918c0d7764cf091e5db94770f62d534c9f96 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 15:29:27 +0800 Subject: [PATCH 08/60] wip --- src/twinkle/model/transformers/transformers.py | 5 +++-- .../model/transformers/test_checkpoint_resume.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d588b8e7..e2cd9206 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1064,7 +1064,6 @@ def load_training_state(self, checkpoint_dir, **kwargs): 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), 'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'), - 'scaler': os.path.join(checkpoint_dir, 'scaler.pt'), 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), } for path in required_paths.values(): @@ -1073,7 +1072,9 @@ def load_training_state(self, checkpoint_dir, **kwargs): trainer_state = self.read_training_progress(checkpoint_dir) self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True) - self._load_scaler_state(required_paths['scaler'], adapter_name=adapter_name) + scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') + if os.path.exists(scaler_path) and optimizer_config.scaler is not None: + self._load_scaler_state(scaler_path, adapter_name=adapter_name) self._load_rng_state(required_paths['rng']) optimizer_config.cur_step = trainer_state['cur_step'] diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py index f6b5b6a1..7541cbaa 100644 --- a/tests/model/transformers/test_checkpoint_resume.py +++ b/tests/model/transformers/test_checkpoint_resume.py @@ -168,7 +168,6 @@ def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir): [ ('optimizer.pt', 'optimizer.pt'), ('scheduler.pt', 'scheduler.pt'), - ('scaler.pt', 'scaler.pt'), ('rng_state.pt', 'rng_state.pt'), ('trainer_state.json', 'trainer_state.json'), ], @@ -188,6 +187,21 @@ def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_loc restored.load_training_state(str(ckpt_dir)) +def test_load_training_state_allows_missing_scaler_file(tmp_path, tiny_local_model_dir): + ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=8, consumed_train_samples=16) + (ckpt_dir / 'scaler.pt').unlink() + + restored = build_full_param_model(ckpt_dir) + restored.set_optimizer('AdamW', lr=1e-4) + restored.set_lr_scheduler('LinearLR') + + trainer_state = restored.load_training_state(str(ckpt_dir)) + + assert trainer_state['cur_step'] == 8 + assert trainer_state['consumed_train_samples'] == 16 + assert restored.optimizer_group[''].cur_step == 8 + + def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir): ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8') From 1e595317e71baca5e08ef8102c2087ea1d308421 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 15:51:40 +0800 Subject: [PATCH 09/60] fix --- src/twinkle/model/transformers/transformers.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index e2cd9206..1c4c57fd 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -148,6 +148,19 @@ def calculate_metrics(self, is_training): return results +def _normalize_checkpoint_state(value: Any): + """Convert nested DTensor state into plain CPU tensors for checkpointing.""" + if isinstance(value, dict): + return {k: _normalize_checkpoint_state(v) for k, v in value.items()} + if isinstance(value, list): + return [_normalize_checkpoint_state(v) for v in value] + if isinstance(value, tuple): + return tuple(_normalize_checkpoint_state(v) for v in value) + if torch.is_tensor(value): + return Torch.to_local_tensor(value).cpu() + return value + + _default_adapter_name = '' DEFAULT_LEARNING_RATE = 1e-5 DEFAULT_WEIGHT_DECAY = 0.01 @@ -884,7 +897,8 @@ def _save_optimizer(self, output_dir, **kwargs): optimizer = optimizer_config.optimizer lr_scheduler = optimizer_config.lr_scheduler if optimizer is not None: - torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt')) + state_dict = _normalize_checkpoint_state(optimizer.state_dict()) + torch.save(state_dict, os.path.join(output_dir, 'optimizer.pt')) if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) From 9bb3f39e0ad7c0cdf61cc1c787336f97d78bccaa Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 16:13:05 +0800 Subject: [PATCH 10/60] wip --- .../model/transformers/strategy/accelerate.py | 49 +++++++++++++++++++ .../model/transformers/transformers.py | 36 ++++++-------- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 89b497e2..214623fa 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -124,6 +124,55 @@ def wrap_model(self, model, *args): def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) + def _prepare_fsdp2_sd_options(self): + fsdp_plugin = self.accelerator.state.fsdp_plugin + if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2: + return None + + from torch.distributed.checkpoint.state_dict import StateDictOptions + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + return StateDictOptions( + full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT, + cpu_offload=getattr(fsdp_plugin.state_dict_config, 'offload_to_cpu', False), + broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, 'rank0_only', False), + ) + + def needs_wrapped_optimizer_state(self) -> bool: + fsdp_plugin = self.accelerator.state.fsdp_plugin + return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2 + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + fsdp_plugin = self.accelerator.state.fsdp_plugin + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + import torch + + optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options()) + if self.accelerator.process_index == 0: + torch.save(optim_state, output_path) + return + + import torch + if self.accelerator.process_index == 0: + torch.save(optimizer.state_dict(), output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + fsdp_plugin = self.accelerator.state.fsdp_plugin + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + import torch + + optim_state = None + rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False) + if self.accelerator.process_index == 0 or not rank0_only: + optim_state = torch.load(input_path, weights_only=True) + set_optimizer_state_dict(model, optimizer, optim_state, options=self._prepare_fsdp2_sd_options()) + return + + import torch + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + def get_full_state_dict(self, model) -> dict: """Collect full state dict.""" from twinkle.utils import torch_util diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 1c4c57fd..40473420 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -147,20 +147,6 @@ def calculate_metrics(self, is_training): self.outputs = None return results - -def _normalize_checkpoint_state(value: Any): - """Convert nested DTensor state into plain CPU tensors for checkpointing.""" - if isinstance(value, dict): - return {k: _normalize_checkpoint_state(v) for k, v in value.items()} - if isinstance(value, list): - return [_normalize_checkpoint_state(v) for v in value] - if isinstance(value, tuple): - return tuple(_normalize_checkpoint_state(v) for v in value) - if torch.is_tensor(value): - return Torch.to_local_tensor(value).cpu() - return value - - _default_adapter_name = '' DEFAULT_LEARNING_RATE = 1e-5 DEFAULT_WEIGHT_DECAY = 0.01 @@ -892,13 +878,16 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int def _save_optimizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] + optimizer = optimizer_config.optimizer + lr_scheduler = optimizer_config.lr_scheduler + if optimizer is not None: + optimizer_path = os.path.join(output_dir, 'optimizer.pt') + if hasattr(self.strategy, 'save_optimizer_checkpoint'): + self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path) + elif Platform.is_master(): + torch.save(optimizer.state_dict(), optimizer_path) if Platform.is_master(): - optimizer = optimizer_config.optimizer - lr_scheduler = optimizer_config.lr_scheduler - if optimizer is not None: - state_dict = _normalize_checkpoint_state(optimizer.state_dict()) - torch.save(state_dict, os.path.join(output_dir, 'optimizer.pt')) if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) @@ -1007,8 +996,13 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): raise FileNotFoundError(scheduler_path) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: - state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) - optimizer_config.optimizer.load_state_dict(state_dict) + if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: + self._lazy_wrap_model() + if hasattr(self.strategy, 'load_optimizer_checkpoint'): + self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) + else: + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) + optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False) From fdf1f71942be0627d8f12f09ab22abdb7a98d555 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 16:54:01 +0800 Subject: [PATCH 11/60] fix --- src/twinkle/model/transformers/transformers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 40473420..d7e8183a 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -961,10 +961,11 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): model_sd = model.state_dict() converted_weights = {} for key, value in adapter_weights.items(): - if f'.{adapter_name}.weight' not in key: - key = key.replace('.weight', f'.{adapter_name}.weight') - if key in model_sd: - param = model_sd[key] + model_key = key + if f'.{adapter_name}.weight' not in model_key: + model_key = model_key.replace('.weight', f'.{adapter_name}.weight') + if model_key in model_sd: + param = model_sd[model_key] if isinstance(param, DTensor) and not isinstance(value, DTensor): value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements) converted_weights[key] = value From 6cf51606ade61c7a0d4a07190e589d22d8c8fc68 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 30 Mar 2026 18:34:31 +0800 Subject: [PATCH 12/60] wip --- .../model/transformers/strategy/accelerate.py | 12 +++-- .../transformers/strategy/native_fsdp.py | 47 +++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 214623fa..bcfd5e30 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -124,8 +124,12 @@ def wrap_model(self, model, *args): def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) + def _get_fsdp_plugin(self): + state = self.accelerator.state + return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None + def _prepare_fsdp2_sd_options(self): - fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2: return None @@ -139,11 +143,11 @@ def _prepare_fsdp2_sd_options(self): ) def needs_wrapped_optimizer_state(self) -> bool: - fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin = self._get_fsdp_plugin() return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2 def save_optimizer_checkpoint(self, model, optimizer, output_path: str): - fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict import torch @@ -158,7 +162,7 @@ def save_optimizer_checkpoint(self, model, optimizer, output_path: str): torch.save(optimizer.state_dict(), output_path) def load_optimizer_checkpoint(self, model, optimizer, input_path: str): - fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict import torch diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 48a1da85..ad675006 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -151,6 +151,53 @@ def wrap_model(self, model, optimizer=None): return model, optimizer + def _prepare_optimizer_state_dict_options(self, *, for_load: bool): + from torch.distributed.checkpoint.state_dict import StateDictOptions + + return StateDictOptions( + full_state_dict=True, + cpu_offload=not for_load, + broadcast_from_rank0=for_load, + ) + + def needs_wrapped_optimizer_state(self) -> bool: + return self.device_mesh is not None + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + if Platform.is_master(): + torch.save(optimizer.state_dict(), output_path) + return + + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict( + model, + optimizer, + options=self._prepare_optimizer_state_dict_options(for_load=False), + ) + if Platform.is_master(): + torch.save(optim_state, output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + return + + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + optim_state = {} + if Platform.is_master(): + optim_state = torch.load(input_path, map_location='cpu', weights_only=True) + set_optimizer_state_dict( + model, + optimizer, + optim_state, + options=self._prepare_optimizer_state_dict_options(for_load=True), + ) + def get_ep_clip_kwargs(self, model) -> Dict[str, Any]: """Return EP-aware kwargs for normalize_and_clip_grad_norm.""" model = self.unwrap_model(model) From e21f870c395215b0055fcf20e42645ab95178f46 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 09:52:22 +0800 Subject: [PATCH 13/60] lint --- ...7-transformers-checkpoint-resume-design.md | 1 - .../model/transformers/strategy/accelerate.py | 4 +- .../model/transformers/transformers.py | 5 +- src/twinkle/server/model/twinkle_handlers.py | 14 ++- tests/dataloader/test_dataloader.py | 16 ++- tests/dataloader/test_sampler.py | 21 +++- .../transformers/test_checkpoint_resume.py | 103 +++++++++++++++++- .../model/test_twinkle_resume_routes.py | 25 +++-- 8 files changed, 155 insertions(+), 34 deletions(-) diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md index 7b90baba..4f41c64f 100644 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md @@ -430,4 +430,3 @@ Recommended guidance text: - `ignore_data_skip=True` disables progress restore and starts from step 0 - Full-parameter checkpoints restore weights during model initialization and restore training state afterward - Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data - diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index bcfd5e30..b9b81f59 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -149,8 +149,8 @@ def needs_wrapped_optimizer_state(self) -> bool: def save_optimizer_checkpoint(self, model, optimizer, output_path: str): fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: - from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict import torch + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options()) if self.accelerator.process_index == 0: @@ -164,8 +164,8 @@ def save_optimizer_checkpoint(self, model, optimizer, output_path: str): def load_optimizer_checkpoint(self, model, optimizer, input_path: str): fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: - from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict import torch + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict optim_state = None rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d7e8183a..88366c96 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -2,11 +2,11 @@ import asyncio import contextlib import json +import numpy as np import os import random import re import threading -import numpy as np import torch import torch.distributed as dist import transformers @@ -147,6 +147,7 @@ def calculate_metrics(self, is_training): self.outputs = None return results + _default_adapter_name = '' DEFAULT_LEARNING_RATE = 1e-5 DEFAULT_WEIGHT_DECAY = 0.01 @@ -1055,7 +1056,7 @@ def read_training_progress(self, checkpoint_dir, **kwargs): if not os.path.exists(trainer_state_path): raise FileNotFoundError(trainer_state_path) - with open(trainer_state_path, 'r', encoding='utf-8') as f: + with open(trainer_state_path, encoding='utf-8') as f: trainer_state = json.load(f) required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'} diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 259c50c4..b48b4e12 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -8,11 +8,11 @@ """ from __future__ import annotations +import os import torch import traceback -import os -from pathlib import Path from fastapi import Depends, FastAPI, HTTPException, Request +from pathlib import Path from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable @@ -363,8 +363,9 @@ async def _task(): extra_kwargs = body.model_extra or {} checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') resolved = checkpoint_manager.resolve_load_path(body.name) - checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() - if resolved.checkpoint_dir else body.name) + checkpoint_dir = ( + Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) self.model.load_training_state( checkpoint_dir, adapter_name=adapter_name, @@ -387,8 +388,9 @@ async def _task(): extra_kwargs = body.model_extra or {} checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') resolved = checkpoint_manager.resolve_load_path(body.name) - checkpoint_dir = (Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() - if resolved.checkpoint_dir else body.name) + checkpoint_dir = ( + Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() + if resolved.checkpoint_dir else body.name) ret = self.model.read_training_progress( checkpoint_dir, adapter_name=adapter_name, diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 232c6fe6..82a4f41b 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -39,10 +39,18 @@ def convert_to_messages(example): def _build_resume_rows(): return [ - {'text': 'Hello world'}, - {'text': 'Test data'}, - {'text': 'Another example'}, - {'text': 'Sample text'}, + { + 'text': 'Hello world' + }, + { + 'text': 'Test data' + }, + { + 'text': 'Another example' + }, + { + 'text': 'Sample text' + }, ] diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index 1dd7d7ca..d5c97dbc 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import concurrent.futures -import os import numpy as np +import os import pytest from pathlib import Path -from torch.utils.data import RandomSampler, SequentialSampler from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import RandomSampler, SequentialSampler class _NoOpProcessPoolExecutor: @@ -183,6 +183,7 @@ def test_sequential_vs_random_order(self): class TestResumeSkipSamplerOrdering: def test_sequential_sampler_skip_happens_before_device_mesh_slice(self): + class _InMemoryDataset(TorchDataset): def __init__(self, rows): @@ -195,10 +196,18 @@ def __getitem__(self, idx): return self.rows[idx] dataset = _InMemoryDataset([ - {'text': 'Hello world'}, - {'text': 'Test data'}, - {'text': 'Another example'}, - {'text': 'Sample text'}, + { + 'text': 'Hello world' + }, + { + 'text': 'Test data' + }, + { + 'text': 'Another example' + }, + { + 'text': 'Sample text' + }, ]) sampler = SequentialSampler(dataset) device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', )) diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py index 7541cbaa..3e903d82 100644 --- a/tests/model/transformers/test_checkpoint_resume.py +++ b/tests/model/transformers/test_checkpoint_resume.py @@ -1,19 +1,20 @@ import concurrent.futures +import pytest import sys +import torch import types import uuid from pathlib import Path -from types import ModuleType - -import pytest from peft import LoraConfig from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace -from transformers import Qwen3Config, Qwen3ForCausalLM, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerFast, Qwen3Config, Qwen3ForCausalLM +from types import ModuleType class _NoOpProcessPoolExecutor: + def __init__(self, *args, **kwargs): pass @@ -39,6 +40,7 @@ def submit(self, fn, *args, **kwargs): sys.modules['zmq'] = fake_zmq from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel +from twinkle.model.transformers.strategy import NativeFSDPStrategy def build_tiny_tokenizer(): @@ -94,6 +96,16 @@ def build_full_param_model(model_dir): ) +def build_native_fsdp_strategy(): + device_mesh = types.SimpleNamespace( + world_size=2, + ep_size=1, + ep_fsdp_size=None, + device_type='cpu', + ) + return NativeFSDPStrategy(device_mesh=device_mesh) + + def build_multi_lora_model(model_dir): device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1) model = MultiLoraTransformersModel( @@ -103,7 +115,8 @@ def build_multi_lora_model(model_dir): device_mesh=device_mesh, grad_scaler_config={}, ) - model.add_adapter_to_model('default', LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])) + model.add_adapter_to_model('default', + LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])) return model @@ -173,7 +186,7 @@ def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir): ], ) def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name, - expected_pattern): + expected_pattern): ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) _ensure_file_exists(ckpt_dir / missing_name) (ckpt_dir / missing_name).unlink() @@ -291,3 +304,81 @@ def test_read_training_progress_returns_metadata_without_mutating_optimizer_stat assert progress['consumed_train_samples'] == 12 assert restored.optimizer_group[''].cur_step == 0 assert restored.optimizer_group[''].gradient_accumulation_steps == 1 + + +def test_native_fsdp_strategy_requires_wrapped_optimizer_state(): + strategy = build_native_fsdp_strategy() + + assert strategy.needs_wrapped_optimizer_state() is True + + +def test_native_fsdp_strategy_save_optimizer_checkpoint_uses_full_state_dict(monkeypatch, tmp_path): + strategy = build_native_fsdp_strategy() + model = object() + optimizer = object() + optimizer_path = tmp_path / 'optimizer.pt' + captured = {} + + from torch.distributed.checkpoint import state_dict as checkpoint_state_dict + + def fake_get_optimizer_state_dict(model_arg, optimizer_arg, *, options=None): + captured['model'] = model_arg + captured['optimizer'] = optimizer_arg + captured['options'] = options + return {'state': {'step': 3}} + + def fake_save(obj, path): + captured['saved_obj'] = obj + captured['saved_path'] = path + + monkeypatch.setattr(checkpoint_state_dict, 'get_optimizer_state_dict', fake_get_optimizer_state_dict) + monkeypatch.setattr(torch, 'save', fake_save) + + strategy.save_optimizer_checkpoint(model, optimizer, str(optimizer_path)) + + assert captured['model'] is model + assert captured['optimizer'] is optimizer + assert captured['saved_obj'] == {'state': {'step': 3}} + assert captured['saved_path'] == str(optimizer_path) + assert captured['options'].full_state_dict is True + assert captured['options'].cpu_offload is True + assert captured['options'].broadcast_from_rank0 is False + + +def test_native_fsdp_strategy_load_optimizer_checkpoint_broadcasts_from_rank0(monkeypatch, tmp_path): + strategy = build_native_fsdp_strategy() + model = object() + optimizer = object() + optimizer_path = tmp_path / 'optimizer.pt' + optimizer_path.write_bytes(b'placeholder') + expected_state = {'state': {'step': 7}} + captured = {} + + from torch.distributed.checkpoint import state_dict as checkpoint_state_dict + + def fake_load(path, map_location=None, weights_only=None): + captured['loaded_path'] = path + captured['map_location'] = map_location + captured['weights_only'] = weights_only + return expected_state + + def fake_set_optimizer_state_dict(model_arg, optimizer_arg, optim_state_dict, *, options=None): + captured['model'] = model_arg + captured['optimizer'] = optimizer_arg + captured['optim_state_dict'] = optim_state_dict + captured['options'] = options + + monkeypatch.setattr(torch, 'load', fake_load) + monkeypatch.setattr(checkpoint_state_dict, 'set_optimizer_state_dict', fake_set_optimizer_state_dict) + + strategy.load_optimizer_checkpoint(model, optimizer, str(optimizer_path)) + + assert captured['loaded_path'] == str(optimizer_path) + assert captured['map_location'] == 'cpu' + assert captured['weights_only'] is True + assert captured['model'] is model + assert captured['optimizer'] is optimizer + assert captured['optim_state_dict'] == expected_state + assert captured['options'].full_state_dict is True + assert captured['options'].cpu_offload is False + assert captured['options'].broadcast_from_rank0 is True diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py index f3523913..cd0fcfde 100644 --- a/tests/server/model/test_twinkle_resume_routes.py +++ b/tests/server/model/test_twinkle_resume_routes.py @@ -1,13 +1,12 @@ import concurrent.futures import importlib.util import sys +from fastapi import FastAPI +from fastapi.testclient import TestClient from pathlib import Path from types import ModuleType from unittest.mock import Mock -from fastapi import FastAPI -from fastapi.testclient import TestClient - class _NoOpProcessPoolExecutor: @@ -86,7 +85,10 @@ def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monk response = client.post( '/twinkle/load_training_state', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': '' + }, ) assert response.status_code == 200 @@ -102,7 +104,10 @@ def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch): response = client.post( '/twinkle/load_training_state', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': 'adapter-a'}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': 'adapter-a' + }, ) assert response.status_code == 200 @@ -117,7 +122,10 @@ def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(mon response = client.post( '/twinkle/load_training_state', - json={'name': 'local-checkpoint-dir', 'adapter_name': ''}, + json={ + 'name': 'local-checkpoint-dir', + 'adapter_name': '' + }, ) assert response.status_code == 200 @@ -133,7 +141,10 @@ def test_read_training_progress_route_returns_progress_and_calls_model(monkeypat response = client.post( '/twinkle/read_training_progress', - json={'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', 'adapter_name': ''}, + json={ + 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', + 'adapter_name': '' + }, ) assert response.status_code == 200 From 70ebe50c4ffc1559bf6f3cef0d71608e3fcc6b3f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 09:54:01 +0800 Subject: [PATCH 14/60] wip --- ...7-transformers-checkpoint-resume-design.md | 432 ------------------ .../transformers/test_checkpoint_resume.py | 384 ---------------- .../model/test_twinkle_resume_routes.py | 156 ------- 3 files changed, 972 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md delete mode 100644 tests/model/transformers/test_checkpoint_resume.py delete mode 100644 tests/server/model/test_twinkle_resume_routes.py diff --git a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md b/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md deleted file mode 100644 index 4f41c64f..00000000 --- a/docs/superpowers/specs/2026-03-27-transformers-checkpoint-resume-design.md +++ /dev/null @@ -1,432 +0,0 @@ -# Transformers Strict Resume Design - -## Summary - -This design adds real checkpoint resumption support for `TransformersModel` without introducing a new trainer class. - -The design supports both full-parameter training and LoRA training: - -- full-parameter training restores weights during model initialization -- LoRA training restores adapter weights through the existing load path -- both modes share the same training-state resume contract -- strict model-state resume does not silently fall back to weight-only loading when required state is missing - -Because Twinkle keeps the training loop explicit in user code, the design extends existing model, dataloader, server, and client interfaces rather than adding a central trainer abstraction. - -## Goals - -- Support true checkpoint resume for `TransformersModel` -- Support both full-parameter and LoRA training resume -- Restore optimizer state, scheduler state, scaler state, RNG state, and step counters -- Support dataset progress skipping for map-style datasets -- Expose Swift-like resume controls without adding a new trainer class -- Keep training-state save and load compatible with NPU (Ascend) environments -- Preserve existing weight-only loading and saving behavior - -## Non-Goals - -- Do not introduce a new `Trainer` class or resume manager class -- Do not guarantee exact sample-by-sample replay when retry-based sampling changes sample order -- Do not support exact data-progress resume for `IterableDataset` or streaming datasets -- Do not attempt to persist transient runtime state such as in-flight batch tensors, current loss tensors, or metric caches - -## User-Facing Resume Controls - -Resume behavior is controlled by existing training entrypoints through three new parameters: - -- `resume_from_checkpoint: Optional[str] = None` -- `resume_only_model: bool = False` -- `ignore_data_skip: bool = False` - -### Parameter semantics - -#### `resume_from_checkpoint` - -- Specifies the checkpoint directory or checkpoint path to resume from -- When unset, training starts normally from scratch -- When set, the training entrypoint reads the checkpoint and restores model state through existing model APIs - -#### `resume_only_model` - -- Defaults to `False` -- When `False`, resume restores full training state -- When `True`, resume restores only model weights - -#### `ignore_data_skip` - -- Only meaningful when `resume_from_checkpoint` is set and `resume_only_model=True` -- Defaults to `False` -- When `False`, the system still restores training progress metadata needed for data skipping and step/epoch continuation, but does not restore optimizer, scheduler, scaler, or RNG -- When `True`, the system restores only model weights and does not restore training progress or skip consumed data - -### Effective behavior matrix - -#### Case 1: `resume_from_checkpoint is None` - -- Start a new training run - -#### Case 2: `resume_from_checkpoint is not None` and `resume_only_model=False` - -- Restore model weights -- Restore optimizer state -- Restore scheduler state -- Restore scaler state -- Restore RNG state -- Restore step counters -- Attempt to skip already consumed training data -- If required model training state is missing, fail without fallback - -#### Case 3: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=False` - -- Restore model weights only -- Do not restore optimizer, scheduler, scaler, or RNG -- Restore step/progress metadata needed for data skipping -- Attempt to skip already consumed training data - -#### Case 4: `resume_from_checkpoint is not None` and `resume_only_model=True` and `ignore_data_skip=True` - -- Restore model weights only -- Do not restore optimizer, scheduler, scaler, RNG, step counters, or data progress -- Restart the training loop from step 0 with no skipping - -## Checkpoint Layout - -Existing weight layouts remain valid. New training-state files are added alongside current checkpoint contents. - -### Existing files preserved - -- full-model weights saved by `save_pretrained` -- LoRA weights saved as `adapter_model.safetensors` -- tokenizer artifacts -- `optimizer.pt` -- `scheduler.pt` - -### New training-state files - -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -### `trainer_state.json` contents - -`trainer_state.json` stores lightweight training metadata: - -- `checkpoint_version` -- `cur_step` -- `gradient_accumulation_steps` -- `consumed_train_samples` - -The design prefers storing `consumed_train_samples` as the canonical progress value and deriving batch skipping from it where needed. - -### `scaler.pt` contents - -- AMP scaler state dict -- optional scaler-related flags such as `scaler_has_nan` - -### `rng_state.pt` contents - -- Python `random` state -- NumPy RNG state -- PyTorch CPU RNG state -- CUDA RNG state - -## Accelerator Compatibility - -Training-state save and load must be accelerator-compatible, including Ascend NPU environments already supported by Twinkle. - -### Device-agnostic serialization - -Training-state files must use device-agnostic serialization: - -- optimizer, scheduler, scaler, and RNG payloads should be serialized in CPU-safe form -- JSON metadata stays in plain text files -- loading should first read state from CPU-safe files and then apply it to objects created on the current runtime device - -This avoids tying resume files to a specific device object layout during save. - -### RNG compatibility requirements - -RNG save and restore must branch by current accelerator backend: - -- CUDA runtime uses `torch.cuda` RNG APIs -- NPU runtime uses `torch.npu` RNG APIs -- CPU RNG and Python/NumPy RNG are always restored - -The implementation must not assume CUDA-only RNG helpers when saving or restoring training state. - -### Scope of compatibility - -The design requires resume support to work correctly in NPU environments. - -The design does not require cross-accelerator resume guarantees such as saving on GPU and resuming on NPU, or saving on NPU and resuming on GPU. The compatibility target is correct save and restore within the active supported accelerator backend. -## Restore Paths - -## Full-Parameter Training - -For full-parameter training, model weights are restored during initialization. - -### Full-parameter restore flow - -1. Construct `TransformersModel(model_id=ckpt_dir, ...)` -2. `__init__` uses `from_pretrained(ckpt_dir, ...)` to restore weights -3. Create optimizer, scheduler, and scaler objects -4. Call `load_training_state(ckpt_dir)` to restore training state -5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` - -This means full-parameter resume does not need a separate model-weight loading method after initialization. It only needs explicit training-state restoration. - -## LoRA Training - -For LoRA training, the existing adapter-weight load path remains in place. - -### LoRA restore flow - -1. Construct the model and adapter objects as today -2. Restore adapter weights through the existing `load()` path -3. Create optimizer, scheduler, and scaler objects -4. Call the same `load_training_state(ckpt_dir)` method to restore training state -5. If data skipping is enabled, rebuild dataloader with skip arguments derived from `trainer_state.json` - -## Unified training-state method - -The model layer gains a shared helper such as `load_training_state(ckpt_dir)`. - -This method restores: - -- `optimizer.pt` -- `scheduler.pt` -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -It assumes the corresponding optimizer, scheduler, and scaler objects have already been created before invocation. - -## Model Save and Load Semantics - -## Save behavior - -When saving with optimizer state enabled, the checkpoint includes: - -- weights in the existing full-model or LoRA format -- tokenizer artifacts -- `optimizer.pt` -- `scheduler.pt` -- `scaler.pt` -- `trainer_state.json` -- `rng_state.pt` - -When optimizer save is disabled, save remains weight-only and does not produce strict resume metadata. - -## Strict training-state restore - -Strict model-state resume restores: - -- optimizer state -- scheduler state -- scaler state -- RNG state -- `cur_step` -- `gradient_accumulation_steps` -- data-progress metadata - -### Failure behavior - -When strict training-state restore is requested, missing required model training state is an error: - -- missing `trainer_state.json` -> fail -- missing `optimizer.pt` when optimizer restore is required -> fail -- missing `scheduler.pt` when scheduler restore is required -> fail -- missing `scaler.pt` when scaler restore is required -> fail -- missing `rng_state.pt` when RNG restore is required -> fail -- malformed required fields -> fail - -This intentionally does not fall back to weight-only loading, to avoid falsely signaling successful strict resume. - -## Training Progress and Data Skipping - -Twinkle does not currently have a central trainer abstraction. Because of that, data skipping must be driven by existing training entrypoints and dataloader arguments. - -## Dataloader extensions - -Existing dataloader and sampler code is extended rather than replaced: - -- `twinkle.dataloader.DataLoader` -- `twinkle.dataloader.DeviceMeshSampler` -- retry-aware sampler flow - -The dataloader gains resume-oriented arguments: - -- `skip_samples: int = 0` -- optionally `skip_batches: int = 0` - -Map-style datasets use this progress to skip already consumed data before yielding new training batches. - -## Map-style dataset behavior - -For datasets with `__len__`, Twinkle attempts to skip previously consumed data using sampler or batch-sampler level skipping. - -Preferred behavior: - -- preserve existing sharding logic -- apply skip before data is yielded to the training loop -- keep the solution compatible with current `DeviceMeshSampler` wrapping - -## Iterable and streaming behavior - -`IterableDataset` and streaming datasets do not support exact progress skipping in this design. - -Behavior for these datasets: - -- restore model state according to the selected resume mode -- log a clear warning that consumed-data skipping is not supported -- continue training without skipping historical samples - -This is the only fallback allowed in the design. It applies only to dataset progress skipping, not to model-state resume. - -## Entry Point Integration - -No new trainer class is introduced. - -Resume parameters are threaded through existing training entrypoints: - -- direct local training loops using `TwinkleModel` / `TransformersModel` -- current client/server training flows that already support checkpoint save and load - -The practical integration model is: - -1. Parse or receive the three resume parameters -2. If `resume_from_checkpoint` is unset, construct dataloader normally -3. Construct model weights through the appropriate path - - full-parameter: restore through `__init__` - - LoRA: restore through existing adapter load logic -4. If `resume_only_model=False`, call `load_training_state(ckpt_dir)` -5. If `resume_only_model=True` and `ignore_data_skip=False`, read `trainer_state.json` for progress only -6. Recreate the dataloader with skip arguments applied when skipping is enabled - -This keeps the training loop explicit and compatible with current Twinkle examples. - -## Server and Client Behavior - -Server-side checkpoint save/load behavior should preserve current APIs while adding richer metadata. - -### Save path - -When server-side save endpoints request optimizer save: - -- save the model checkpoint as today -- save `optimizer.pt`, `scheduler.pt`, `scaler.pt`, `trainer_state.json`, and `rng_state.pt` -- persist checkpoint metadata through the existing checkpoint manager - -### Load path - -Current model load APIs remain the weight-loading trigger. - -The new resume parameters are primarily a training-entrypoint concern. They orchestrate whether to: - -- restore full training state -- restore weight only -- request data skipping - -The underlying server model APIs do not need a new trainer object to support this. - -## Compatibility Strategy - -### Existing checkpoints - -Existing checkpoints remain loadable in weight-only mode. - -Examples: - -- weight-only initialization for full-parameter checkpoints continues to work -- existing LoRA weight loading continues to work -- inference-only consumers remain unaffected - -### Old checkpoints under strict resume - -Old checkpoints that lack the new training-state files are not valid for strict resume. - -Expected behavior: - -- strict resume fails clearly -- weight-only load continues to work when requested explicitly - -### `resume_only_model=True` - -For `resume_only_model=True`, old checkpoints may still be usable if weight files are present. - -If data skipping is requested but no progress metadata exists, the entrypoint should fail clearly rather than silently train from the beginning while claiming resumed progress. - -## Risks and Constraints - -### RetrySampler interaction - -`RetrySampler` may retry or replace failed samples, including random backfill behavior at the tail of an epoch. - -Because of that: - -- progress skipping can preserve approximate data position -- exact sample-for-sample replay is not guaranteed when retry or backfill paths are exercised - -This limitation should be documented explicitly. - -### Dataset shape changes - -If dataset definition, slicing, filtering, or shuffle configuration changes between save and resume, data skipping semantics may become invalid. - -The user guidance should state that resume should be done with unchanged training parameters and unchanged dataset configuration. - -### Distributed consistency - -Skip logic must be compatible with current device-mesh sharding. The implementation should ensure skip is applied consistently before per-rank slicing causes divergence. - -## Testing Strategy - -Tests should cover: - -### Full-parameter training resume - -- initializing with `model_id=ckpt_dir` restores weights -- `load_training_state(ckpt_dir)` restores optimizer, scheduler, scaler, RNG, and step metadata - -### LoRA training resume - -- adapter-weight restore continues to work -- `load_training_state(ckpt_dir)` restores shared training state correctly - -### Strict restore failures - -- strict resume fails when required files are missing -- malformed state files fail clearly - -### Weight-only compatibility - -- legacy checkpoints still load in weight-only mode -- `resume_only_model=True` restores weights without optimizer, scheduler, scaler, or RNG - -### Data progress skipping - -- map-style datasets skip consumed data correctly -- skip behavior remains correct with device-mesh sharding -- iterable and streaming datasets emit warnings and continue without skipping - -## Implementation Outline - -1. Add model helpers for saving and loading split training-state files -2. Implement `load_training_state(ckpt_dir)` with shared behavior for full-parameter and LoRA training -3. Keep full-parameter weight restore in `__init__` -4. Keep LoRA weight restore in the existing adapter load path -5. Extend dataloader and sampler stack to support skip arguments for map-style datasets -6. Thread `resume_from_checkpoint`, `resume_only_model`, and `ignore_data_skip` through existing training entrypoints -7. Add warnings for unsupported iterable and streaming data skipping -8. Update docs and examples to show the new resume contract - -## User Guidance - -Recommended guidance text: - -- To resume training, keep other parameters unchanged and provide `resume_from_checkpoint` -- `resume_only_model=False` performs full resume -- `resume_only_model=True` restores only model weights -- `ignore_data_skip=True` disables progress restore and starts from step 0 -- Full-parameter checkpoints restore weights during model initialization and restore training state afterward -- Iterable and streaming datasets do not support consumed-data skipping and will resume without skipping data diff --git a/tests/model/transformers/test_checkpoint_resume.py b/tests/model/transformers/test_checkpoint_resume.py deleted file mode 100644 index 3e903d82..00000000 --- a/tests/model/transformers/test_checkpoint_resume.py +++ /dev/null @@ -1,384 +0,0 @@ -import concurrent.futures -import pytest -import sys -import torch -import types -import uuid -from pathlib import Path -from peft import LoraConfig -from tokenizers import Tokenizer -from tokenizers.models import WordLevel -from tokenizers.pre_tokenizers import Whitespace -from transformers import PreTrainedTokenizerFast, Qwen3Config, Qwen3ForCausalLM -from types import ModuleType - - -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor - -ROOT = Path(__file__).resolve().parents[3] -SRC = ROOT / 'src' -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -if 'zmq' not in sys.modules: - fake_zmq = ModuleType('zmq') - fake_zmq.Socket = object - fake_zmq.Context = object - fake_zmq.REP = 0 - fake_zmq.REQ = 1 - fake_zmq.IPV6 = 2 - fake_zmq.error = types.SimpleNamespace(Again=RuntimeError) - sys.modules['zmq'] = fake_zmq - -from twinkle.model.transformers import MultiLoraTransformersModel, TransformersModel -from twinkle.model.transformers.strategy import NativeFSDPStrategy - - -def build_tiny_tokenizer(): - vocab = {'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, 'hello': 4, 'world': 5} - backend = Tokenizer(WordLevel(vocab=vocab, unk_token='[UNK]')) - backend.pre_tokenizer = Whitespace() - tokenizer = PreTrainedTokenizerFast( - tokenizer_object=backend, - pad_token='[PAD]', - bos_token='[BOS]', - eos_token='[EOS]', - unk_token='[UNK]', - ) - return tokenizer - - -@pytest.fixture -def tmp_path(): - root = Path(__file__).parent / 'tmp_test_checkpoints' - root.mkdir(exist_ok=True) - path = root / uuid.uuid4().hex - path.mkdir() - return path - - -@pytest.fixture -def tiny_local_model_dir(tmp_path): - model_dir = tmp_path / 'tiny-qwen3' - model_dir.mkdir() - - tokenizer = build_tiny_tokenizer() - tokenizer.save_pretrained(model_dir) - - config = Qwen3Config( - vocab_size=tokenizer.vocab_size, - num_hidden_layers=1, - num_attention_heads=1, - hidden_size=16, - max_position_embeddings=32, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - ) - Qwen3ForCausalLM(config).save_pretrained(model_dir) - return model_dir - - -def build_full_param_model(model_dir): - return TransformersModel( - model_cls='Qwen3ForCausalLM', - model_id=str(model_dir), - mixed_precision='no', - grad_scaler_config={}, - ) - - -def build_native_fsdp_strategy(): - device_mesh = types.SimpleNamespace( - world_size=2, - ep_size=1, - ep_fsdp_size=None, - device_type='cpu', - ) - return NativeFSDPStrategy(device_mesh=device_mesh) - - -def build_multi_lora_model(model_dir): - device_mesh = types.SimpleNamespace(fsdp_world_size=0, data_world_size=1) - model = MultiLoraTransformersModel( - model_cls='Qwen3ForCausalLM', - model_id=str(model_dir), - mixed_precision='no', - device_mesh=device_mesh, - grad_scaler_config={}, - ) - model.add_adapter_to_model('default', - LoraConfig(r=2, lora_alpha=4, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])) - return model - - -def prepare_full_param_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): - model = build_full_param_model(model_dir) - model.set_optimizer('AdamW', lr=1e-4) - model.set_lr_scheduler('LinearLR') - model.set_grad_scaler(device='cpu') - model.optimizer_group[''].cur_step = cur_step - return Path( - model.save( - name='full-resume', - output_dir=str(tmp_path), - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - )) - - -def prepare_lora_checkpoint(tmp_path, model_dir, cur_step=3, consumed_train_samples=6): - model = build_multi_lora_model(model_dir) - model.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - model.set_lr_scheduler('LinearLR', adapter_name='default') - model.set_grad_scaler(adapter_name='default', device='cpu') - model.optimizer_group['default'].cur_step = cur_step - return Path( - model.save( - name='lora-resume', - output_dir=str(tmp_path), - adapter_name='default', - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - )) - - -def _ensure_file_exists(path: Path): - if path.exists(): - return - if path.name == 'trainer_state.json': - path.write_text('{"cur_step": 3}', encoding='utf-8') - else: - path.write_bytes(b'placeholder') - - -def test_save_training_state_writes_split_files(tmp_path, tiny_local_model_dir): - model = build_full_param_model(tiny_local_model_dir) - model.set_optimizer('AdamW', lr=1e-4) - model.set_lr_scheduler('LinearLR') - model.set_grad_scaler(device='cpu') - model.optimizer_group[''].cur_step = 7 - - ckpt_dir = Path(model.save(name='resume-step', output_dir=str(tmp_path), save_optimizer=True)) - - assert (ckpt_dir / 'optimizer.pt').exists() - assert (ckpt_dir / 'scheduler.pt').exists() - assert (ckpt_dir / 'scaler.pt').exists() - assert (ckpt_dir / 'trainer_state.json').exists() - assert (ckpt_dir / 'rng_state.pt').exists() - - -@pytest.mark.parametrize( - 'missing_name, expected_pattern', - [ - ('optimizer.pt', 'optimizer.pt'), - ('scheduler.pt', 'scheduler.pt'), - ('rng_state.pt', 'rng_state.pt'), - ('trainer_state.json', 'trainer_state.json'), - ], -) -def test_load_training_state_fails_when_required_file_missing(tmp_path, tiny_local_model_dir, missing_name, - expected_pattern): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) - _ensure_file_exists(ckpt_dir / missing_name) - (ckpt_dir / missing_name).unlink() - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - with pytest.raises(FileNotFoundError, match=expected_pattern): - restored.load_training_state(str(ckpt_dir)) - - -def test_load_training_state_allows_missing_scaler_file(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=8, consumed_train_samples=16) - (ckpt_dir / 'scaler.pt').unlink() - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - - trainer_state = restored.load_training_state(str(ckpt_dir)) - - assert trainer_state['cur_step'] == 8 - assert trainer_state['consumed_train_samples'] == 16 - assert restored.optimizer_group[''].cur_step == 8 - - -def test_load_training_state_fails_for_malformed_trainer_state(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir) - (ckpt_dir / 'trainer_state.json').write_text('{"cur_step": 3}', encoding='utf-8') - - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - with pytest.raises((KeyError, ValueError), match='gradient_accumulation_steps|consumed_train_samples'): - restored.load_training_state(str(ckpt_dir)) - - -def test_full_parameter_resume_restores_training_state_after_init(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=9, consumed_train_samples=18) - saved = build_full_param_model(tiny_local_model_dir) - saved.set_optimizer('AdamW', lr=1e-4) - saved.set_lr_scheduler('LinearLR') - saved.set_grad_scaler(device='cpu') - saved.optimizer_group[''].cur_step = 9 - saved.optimizer_group[''].gradient_accumulation_steps = 4 - ckpt_dir = Path( - saved.save( - name='full-resume', - output_dir=str(tmp_path), - save_optimizer=True, - consumed_train_samples=18, - )) - restored = build_full_param_model(ckpt_dir) - restored.set_optimizer('AdamW', lr=1e-4) - restored.set_lr_scheduler('LinearLR') - restored.set_grad_scaler(device='cpu') - - trainer_state = restored.load_training_state(str(ckpt_dir)) - - assert trainer_state['cur_step'] == 9 - assert trainer_state['gradient_accumulation_steps'] == 4 - assert trainer_state['consumed_train_samples'] == 18 - assert restored.optimizer_group[''].cur_step == 9 - assert restored.optimizer_group[''].gradient_accumulation_steps == 4 - - -def test_lora_load_does_not_restore_training_state_without_explicit_resume(tmp_path, tiny_local_model_dir): - saved = build_multi_lora_model(tiny_local_model_dir) - saved.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - saved.set_lr_scheduler('LinearLR', adapter_name='default') - saved.set_grad_scaler(adapter_name='default', device='cpu') - saved.optimizer_group['default'].cur_step = 5 - saved.optimizer_group['default'].gradient_accumulation_steps = 3 - ckpt_dir = Path( - saved.save( - name='lora-resume', - output_dir=str(tmp_path), - adapter_name='default', - save_optimizer=True, - consumed_train_samples=10, - )) - restored = build_multi_lora_model(tiny_local_model_dir) - - assert restored.optimizer_group['default'].cur_step == 0 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 - - restored.load(name=ckpt_dir.name, output_dir=str(ckpt_dir.parent), adapter_name='default') - assert restored.optimizer_group['default'].cur_step == 0 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 1 - - restored.set_optimizer('AdamW', adapter_name='default', lr=1e-4) - restored.set_lr_scheduler('LinearLR', adapter_name='default') - restored.set_grad_scaler(adapter_name='default', device='cpu') - - trainer_state = restored.load_training_state(str(ckpt_dir), adapter_name='default') - - assert trainer_state['cur_step'] == 5 - assert trainer_state['gradient_accumulation_steps'] == 3 - assert trainer_state['consumed_train_samples'] == 10 - assert restored.optimizer_group['default'].cur_step == 5 - assert restored.optimizer_group['default'].gradient_accumulation_steps == 3 - - -def test_read_training_progress_returns_metadata_without_mutating_optimizer_state(tmp_path, tiny_local_model_dir): - ckpt_dir = prepare_full_param_checkpoint(tmp_path, tiny_local_model_dir, cur_step=6, consumed_train_samples=12) - restored = build_full_param_model(ckpt_dir) - - progress = restored.read_training_progress(str(ckpt_dir)) - - assert progress['cur_step'] == 6 - assert progress['consumed_train_samples'] == 12 - assert restored.optimizer_group[''].cur_step == 0 - assert restored.optimizer_group[''].gradient_accumulation_steps == 1 - - -def test_native_fsdp_strategy_requires_wrapped_optimizer_state(): - strategy = build_native_fsdp_strategy() - - assert strategy.needs_wrapped_optimizer_state() is True - - -def test_native_fsdp_strategy_save_optimizer_checkpoint_uses_full_state_dict(monkeypatch, tmp_path): - strategy = build_native_fsdp_strategy() - model = object() - optimizer = object() - optimizer_path = tmp_path / 'optimizer.pt' - captured = {} - - from torch.distributed.checkpoint import state_dict as checkpoint_state_dict - - def fake_get_optimizer_state_dict(model_arg, optimizer_arg, *, options=None): - captured['model'] = model_arg - captured['optimizer'] = optimizer_arg - captured['options'] = options - return {'state': {'step': 3}} - - def fake_save(obj, path): - captured['saved_obj'] = obj - captured['saved_path'] = path - - monkeypatch.setattr(checkpoint_state_dict, 'get_optimizer_state_dict', fake_get_optimizer_state_dict) - monkeypatch.setattr(torch, 'save', fake_save) - - strategy.save_optimizer_checkpoint(model, optimizer, str(optimizer_path)) - - assert captured['model'] is model - assert captured['optimizer'] is optimizer - assert captured['saved_obj'] == {'state': {'step': 3}} - assert captured['saved_path'] == str(optimizer_path) - assert captured['options'].full_state_dict is True - assert captured['options'].cpu_offload is True - assert captured['options'].broadcast_from_rank0 is False - - -def test_native_fsdp_strategy_load_optimizer_checkpoint_broadcasts_from_rank0(monkeypatch, tmp_path): - strategy = build_native_fsdp_strategy() - model = object() - optimizer = object() - optimizer_path = tmp_path / 'optimizer.pt' - optimizer_path.write_bytes(b'placeholder') - expected_state = {'state': {'step': 7}} - captured = {} - - from torch.distributed.checkpoint import state_dict as checkpoint_state_dict - - def fake_load(path, map_location=None, weights_only=None): - captured['loaded_path'] = path - captured['map_location'] = map_location - captured['weights_only'] = weights_only - return expected_state - - def fake_set_optimizer_state_dict(model_arg, optimizer_arg, optim_state_dict, *, options=None): - captured['model'] = model_arg - captured['optimizer'] = optimizer_arg - captured['optim_state_dict'] = optim_state_dict - captured['options'] = options - - monkeypatch.setattr(torch, 'load', fake_load) - monkeypatch.setattr(checkpoint_state_dict, 'set_optimizer_state_dict', fake_set_optimizer_state_dict) - - strategy.load_optimizer_checkpoint(model, optimizer, str(optimizer_path)) - - assert captured['loaded_path'] == str(optimizer_path) - assert captured['map_location'] == 'cpu' - assert captured['weights_only'] is True - assert captured['model'] is model - assert captured['optimizer'] is optimizer - assert captured['optim_state_dict'] == expected_state - assert captured['options'].full_state_dict is True - assert captured['options'].cpu_offload is False - assert captured['options'].broadcast_from_rank0 is True diff --git a/tests/server/model/test_twinkle_resume_routes.py b/tests/server/model/test_twinkle_resume_routes.py deleted file mode 100644 index cd0fcfde..00000000 --- a/tests/server/model/test_twinkle_resume_routes.py +++ /dev/null @@ -1,156 +0,0 @@ -import concurrent.futures -import importlib.util -import sys -from fastapi import FastAPI -from fastapi.testclient import TestClient -from pathlib import Path -from types import ModuleType -from unittest.mock import Mock - - -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor - -ROOT = Path(__file__).resolve().parents[3] -SRC = ROOT / 'src' -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) - -if 'twinkle.server.common.checkpoint_factory' not in sys.modules: - fake_checkpoint_factory = ModuleType('twinkle.server.common.checkpoint_factory') - fake_checkpoint_factory.create_checkpoint_manager = lambda *args, **kwargs: None - fake_checkpoint_factory.create_training_run_manager = lambda *args, **kwargs: None - sys.modules['twinkle.server.common.checkpoint_factory'] = fake_checkpoint_factory - -from twinkle_client.types.checkpoint import ResolvedLoadPath - -_HANDLERS_PATH = SRC / 'twinkle' / 'server' / 'model' / 'twinkle_handlers.py' -_HANDLERS_SPEC = importlib.util.spec_from_file_location('twinkle_resume_test_handlers', _HANDLERS_PATH) -handlers = importlib.util.module_from_spec(_HANDLERS_SPEC) -sys.modules[_HANDLERS_SPEC.name] = handlers -_HANDLERS_SPEC.loader.exec_module(handlers) - - -class _FakeCheckpointManager: - - def __init__(self, checkpoint_dir='./resolved/weights'): - self._checkpoint_dir = checkpoint_dir - - def resolve_load_path(self, path: str) -> ResolvedLoadPath: - return ResolvedLoadPath( - checkpoint_name='ckpt-1', - checkpoint_dir=self._checkpoint_dir, - is_twinkle_path=True, - training_run_id='run-1', - checkpoint_id='weights/ckpt-1', - ) - - -class _FakeModelManagement: - - def __init__(self): - self.model = Mock() - - async def _on_request_start(self, request): - request.state.request_id = 'req-1' - return 'token-1' - - def assert_resource_exists(self, adapter_name): - return None - - async def schedule_task_and_wait(self, task, task_type=''): - return await task() - - -def _build_test_client(monkeypatch, checkpoint_manager=None): - management = _FakeModelManagement() - checkpoint_manager = checkpoint_manager or _FakeCheckpointManager() - monkeypatch.setattr(handlers, 'create_checkpoint_manager', lambda token, client_type='twinkle': checkpoint_manager) - - app = FastAPI() - handlers._register_twinkle_routes(app, lambda: management) - return TestClient(app), management - - -def test_load_training_state_route_resolves_checkpoint_path_and_calls_model(monkeypatch): - client, management = _build_test_client(monkeypatch) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name=None, - ) - management.model.read_training_progress.assert_not_called() - - -def test_load_training_state_route_prefixes_non_empty_adapter_name(monkeypatch): - client, management = _build_test_client(monkeypatch) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': 'adapter-a' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name='req-1-adapter-a', - ) - - -def test_load_training_state_route_uses_raw_name_when_checkpoint_dir_missing(monkeypatch): - client, management = _build_test_client(monkeypatch, checkpoint_manager=_FakeCheckpointManager(checkpoint_dir=None)) - - response = client.post( - '/twinkle/load_training_state', - json={ - 'name': 'local-checkpoint-dir', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - management.model.load_training_state.assert_called_once_with( - 'local-checkpoint-dir', - adapter_name=None, - ) - - -def test_read_training_progress_route_returns_progress_and_calls_model(monkeypatch): - client, management = _build_test_client(monkeypatch) - management.model.read_training_progress.return_value = {'cur_step': 6, 'consumed_train_samples': 12} - - response = client.post( - '/twinkle/read_training_progress', - json={ - 'name': 'twinkle://training_runs/run-1/checkpoints/weights/ckpt-1', - 'adapter_name': '' - }, - ) - - assert response.status_code == 200 - assert response.json()['result']['consumed_train_samples'] == 12 - management.model.read_training_progress.assert_called_once_with( - 'resolved/weights/ckpt-1', - adapter_name=None, - ) - management.model.load_training_state.assert_not_called() From 483778d7482eb506138343f308fe25e1d1c22b04 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 10:29:29 +0800 Subject: [PATCH 15/60] wip --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index ae4d2cb3..58f495d4 100644 --- a/.gitignore +++ b/.gitignore @@ -29,7 +29,6 @@ wheels/ /temp MANIFEST .locks/ -tmp_test_checkpoints/ # PyInstaller # Usually these files are written by a python script from a template @@ -144,7 +143,6 @@ images /custom/ megatron_output/ .qoder -.worktrees/ # Pytorch *.pth From 039789b72f695dd9f20fbf9a396d5bdf44f5e835 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 14:32:32 +0800 Subject: [PATCH 16/60] wip --- client_tools/client_generator.py | 3 ++- .../client/twinkle/self_host/self_congnition.py | 15 ++++++++++++--- src/twinkle/server/model/twinkle_handlers.py | 9 +++++---- src/twinkle_client/dataloader/dataloader.py | 13 +++++++++++++ .../model/multi_lora_transformers.py | 3 ++- 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index f724c7c1..3dc99eba 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -618,13 +618,14 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> None: + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" response = http_post( url=f'{self.server_url}/load_training_state', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() + return TrainingProgressResponse(**response.json()).result def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: """Read progress-only checkpoint metadata for resume-only-model flows.""" diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index e31daaba..3975f2a8 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -99,9 +99,13 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint + consumed_train_samples = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) @@ -114,6 +118,7 @@ def train(): # Step model.clip_grad_and_step() + consumed_train_samples += len(batch) # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -131,7 +136,11 @@ def train(): logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 9: Upload the checkpoint to ModelScope Hub diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index b48b4e12..250fdc57 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -349,12 +349,12 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, task_type='load')) - @app.post('/twinkle/load_training_state') + @app.post('/twinkle/load_training_state', response_model=types.TrainingProgressResponse) async def load_training_state( request: Request, body: types.LoadTrainingStateRequest, self: ModelManagement = Depends(self_fn), - ) -> None: + ) -> types.TrainingProgressResponse: token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -366,13 +366,14 @@ async def _task(): checkpoint_dir = ( Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() if resolved.checkpoint_dir else body.name) - self.model.load_training_state( + ret = self.model.load_training_state( checkpoint_dir, adapter_name=adapter_name, **extra_kwargs, ) + return {'result': ret} - await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) + return await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) async def read_training_progress( diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 0a067ddd..f6a24fe4 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -82,4 +82,17 @@ def __next__(self): ) response.raise_for_status() return response.json()["result"] + + + def skip_consumed_samples(self, consumed_train_samples: int): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'skip_consumed_samples', + **{'consumed_train_samples': consumed_train_samples}, + } + ) + response.raise_for_status() + return response.json()["result"] \ No newline at end of file diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 2a8afdce..37eac765 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -189,13 +189,14 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> None: + def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" response = http_post( url=f'{self.server_url}/load_training_state', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() + return TrainingProgressResponse(**response.json()).result def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: """Read progress-only checkpoint metadata for resume-only-model flows.""" From 54de1a44b9f36316f2c6198b3e13f0ae63a7dc7a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:43:57 +0800 Subject: [PATCH 17/60] wip --- src/twinkle/model/transformers/transformers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index fca3909c..acda5da0 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -39,11 +39,15 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template +from twinkle.utils.logger import get_logger from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +logger = get_logger() + + @dataclass class OptimizerGroup(BaseOptimizerGroup): """Optimizer group for Transformers training.""" @@ -983,8 +987,10 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): if strict and not os.path.exists(optimizer_path): raise FileNotFoundError(optimizer_path) - if strict and not os.path.exists(scheduler_path): - raise FileNotFoundError(scheduler_path) + if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path): + logger.warning( + f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', + ) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: @@ -1062,7 +1068,6 @@ def load_training_state(self, checkpoint_dir, **kwargs): required_paths = { 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), - 'scheduler': os.path.join(checkpoint_dir, 'scheduler.pt'), 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), } for path in required_paths.values(): From 920ab869a7a96e065af3412fedc408d15835dd24 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:45:28 +0800 Subject: [PATCH 18/60] wip --- .../twinkle/self_host/self_congnition.py | 9 +++-- cookbook/transformers/resume_utils.py | 40 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 cookbook/transformers/resume_utils.py diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 3975f2a8..967cddbe 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -100,25 +100,28 @@ def train(): # Step 6: Optionally resume from a previous checkpoint consumed_train_samples = 0 + global_step = 0 if resume_path: logger.info(f'Resuming model weights from {resume_path}') model.load(resume_path) trainer_state = model.load_training_state(resume_path) dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for _, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() consumed_train_samples += len(batch) + global_step += 1 # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -130,10 +133,10 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - if step % 2 == 0: + if global_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint twinkle_path = model.save( diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py new file mode 100644 index 00000000..86745d0a --- /dev/null +++ b/cookbook/transformers/resume_utils.py @@ -0,0 +1,40 @@ +from pathlib import Path +from typing import Any, Optional +from twinkle import get_logger + + +logger = get_logger() + + +def resume_from_checkpoint( + model: Any, + dataloader: Any, + checkpoint_path: Path, + *, + resume_only_model: bool, + ignore_data_skip: bool, + adapter_name: Optional[str] = None) -> int: + checkpoint_dir = str(checkpoint_path) + model_kwargs = {} + if adapter_name is not None: + model_kwargs['adapter_name'] = adapter_name + + if resume_only_model: + if ignore_data_skip: + logger.info('Resumed weights only and restarted progress from step 0.') + return 0 + progress = model.read_training_progress(checkpoint_dir, **model_kwargs) + dataloader.skip_consumed_samples(progress['consumed_train_samples']) + optimizer_group_name = adapter_name if adapter_name is not None else '' + model.optimizer_group[optimizer_group_name].cur_step = progress['cur_step'] + model.optimizer_group[optimizer_group_name].gradient_accumulation_steps = progress[ + 'gradient_accumulation_steps'] + consumed_train_samples = int(progress['consumed_train_samples']) + logger.info(f'Skipped {consumed_train_samples} consumed samples.') + return consumed_train_samples + + trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.') + return consumed_train_samples From ffd630484a2e4d196bba9afd39a4a8f1ee10981f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:46:32 +0800 Subject: [PATCH 19/60] lint --- src/twinkle/model/transformers/transformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index acda5da0..4e641217 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -39,11 +39,10 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template -from twinkle.utils.logger import get_logger from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm - +from twinkle.utils.logger import get_logger logger = get_logger() @@ -989,8 +988,7 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): raise FileNotFoundError(optimizer_path) if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path): logger.warning( - f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', - ) + f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', ) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: From 582bd41c5e1a059cad7f96d3d2b99dd87a04d0e9 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 31 Mar 2026 15:53:09 +0800 Subject: [PATCH 20/60] wip --- client_tools/client_generator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 3dc99eba..c9f43fab 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. +# Copyright (c) ModelScope Contributors. All rights reserved. import ast from pathlib import Path from typing import Dict, List, Set, Tuple @@ -869,5 +869,4 @@ def apply_patch(self, patch_cls: str, **kwargs) -> None: generate_samplers() print('\n' + '=' * 60) - print('\nAll client code generation complete!\n') - + print('\n✓ All client code generation complete!\n') From 9cb6106b2fc76fccd4c5d6f6cbab7219f948429f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 09:27:29 +0800 Subject: [PATCH 21/60] wip --- ...00\344\275\263\345\256\236\350\267\265.md" | 24 +++---------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index f0112042..ad78e28d 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -410,19 +410,9 @@ def train(): model.set_lr_scheduler('LinearLR') # 恢复训练(如有检查点) - resume_from_checkpoint = resume_path - resume_only_model = False - ignore_data_skip = False - if resume_from_checkpoint: - logger.info(f'Resuming training from {resume_from_checkpoint}') - model.load(name=resume_from_checkpoint) - - if not resume_only_model: - trainer_state = model.load_training_state(resume_from_checkpoint) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - elif not ignore_data_skip: - progress = model.read_training_progress(resume_from_checkpoint) - dataloader.skip_consumed_samples(progress['consumed_train_samples']) + if resume_path: + logger.info(f'Resuming training from {resume_path}') + model.load(resume_path, load_optimizer=True) logger.info(model.get_train_configs()) @@ -455,14 +445,6 @@ if __name__ == '__main__': - 支持断点续训、检查点管理 - 可动态切换 LoRA 适配器、损失函数、优化器等组件 -Resume 模式: - -- `resume_from_checkpoint=None`:开始新的训练任务。 -- `resume_only_model=False`:恢复权重、optimizer、scheduler、scaler、RNG 和进度元数据。 -- `resume_only_model=True` 且 `ignore_data_skip=False`:恢复权重,读取进度元数据,并跳过已消费样本。 -- `resume_only_model=True` 且 `ignore_data_skip=True`:只恢复权重,训练步数和数据进度从 0 开始。 -- `skip_consumed_samples(...)` 不适用于 iterable / streaming dataset。 - ### 3.2 Tinker Client:简洁即用 Tinker 是一个轻量级训练 API。Twinkle 对 Tinker 客户端提供完整支持,几行代码就能拉起训练。已有 Tinker 代码的项目可以直接迎移到 Twinkle 服务端。 From c0cf72e8089e2c40596b2d28d7ee7ea863ba613e Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 10:27:01 +0800 Subject: [PATCH 22/60] wip --- cookbook/transformers/fsdp2.py | 86 ++++++++++++++++++++------- cookbook/transformers/resume_utils.py | 19 ++++-- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 7b6bd2a8..47512629 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm @@ -8,21 +10,39 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, fsdp_size=2, dp=4 -device_mesh = DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +from resume_utils import resume_from_checkpoint logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +FSDP_SIZE = 2 +DP_SIZE = 4 +BATCH_SIZE = 8 +LEARNING_RATE = 1e-4 +GRADIENT_ACCUMULATION_STEPS = 2 +LOG_INTERVAL = 20 +EVAL_INTERVAL = 40 + +OUTPUT_DIR = './output/fsdp2' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +# Construct a device_mesh +device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + def eval(model): # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(100))) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) for step, batch in tqdm(enumerate(dataloader)): model.forward_only(inputs=batch) model.calculate_loss() @@ -32,29 +52,41 @@ def eval(model): def train(): # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset dataset.encode() # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = TransformersModel(model_id=MODEL_ID) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + consumed_train_samples = 0 + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=checkpoint_path, + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name=ADAPTER_NAME, + ) + logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) @@ -67,18 +99,32 @@ def train(): model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + consumed_train_samples += BATCH_SIZE + cur_step = model.optimizer_group[ADAPTER_NAME].cur_step + if cur_step % LOG_INTERVAL == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: metrics = eval(model) logger.info(f'Eval metric: {metrics}') - metrics['step'] = step + metrics['step'] = cur_step if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') + model.save( + f'checkpoint-{cur_step}', + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + model.save( + 'last-checkpoint', + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) if __name__ == '__main__': diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py index 86745d0a..3d2d197b 100644 --- a/cookbook/transformers/resume_utils.py +++ b/cookbook/transformers/resume_utils.py @@ -15,24 +15,33 @@ def resume_from_checkpoint( ignore_data_skip: bool, adapter_name: Optional[str] = None) -> int: checkpoint_dir = str(checkpoint_path) + adapter_name = adapter_name or '' model_kwargs = {} - if adapter_name is not None: + if adapter_name != '': + # Load adapter checkpoint. model_kwargs['adapter_name'] = adapter_name + model.load( + name=checkpoint_path.name, + output_dir=str(checkpoint_path.parent), + **model_kwargs, + ) if resume_only_model: + # Only load model weights, optionally skip data. if ignore_data_skip: logger.info('Resumed weights only and restarted progress from step 0.') return 0 progress = model.read_training_progress(checkpoint_dir, **model_kwargs) + # Skip consumed samples in dataloader and move optimizer to the right step. dataloader.skip_consumed_samples(progress['consumed_train_samples']) - optimizer_group_name = adapter_name if adapter_name is not None else '' - model.optimizer_group[optimizer_group_name].cur_step = progress['cur_step'] - model.optimizer_group[optimizer_group_name].gradient_accumulation_steps = progress[ - 'gradient_accumulation_steps'] + model.optimizer_group[adapter_name].cur_step = progress['cur_step'] + model.optimizer_group[adapter_name].gradient_accumulation_steps = progress['gradient_accumulation_steps'] + consumed_train_samples = int(progress['consumed_train_samples']) logger.info(f'Skipped {consumed_train_samples} consumed samples.') return consumed_train_samples + # Load full training state, including model weights, optimizer states, and training progress. trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) consumed_train_samples = int(trainer_state['consumed_train_samples']) From 505a75cbeb0564a9c914d85749161333405537b3 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 10:41:08 +0800 Subject: [PATCH 23/60] wip --- cookbook/transformers/fsdp2.py | 76 +++++++++++++-------------- cookbook/transformers/resume_utils.py | 24 +++++---- 2 files changed, 53 insertions(+), 47 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 47512629..45dd8ac1 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -16,6 +16,9 @@ MODEL_ID = 'ms://Qwen/Qwen3.5-4B' DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' FSDP_SIZE = 2 DP_SIZE = 4 BATCH_SIZE = 8 @@ -23,6 +26,8 @@ GRADIENT_ACCUMULATION_STEPS = 2 LOG_INTERVAL = 20 EVAL_INTERVAL = 40 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 OUTPUT_DIR = './output/fsdp2' RESUME_FROM_CHECKPOINT = None @@ -36,29 +41,34 @@ twinkle.initialize(mode='local', global_device_mesh=device_mesh) -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, consumed_train_samples: int): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): model.forward_only(inputs=batch) model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() + dataset = build_dataset(TRAIN_SAMPLES) # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel @@ -68,7 +78,7 @@ def train(): lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` @@ -91,40 +101,30 @@ def train(): # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 + optimizer_group = model.optimizer_group[ADAPTER_NAME] + best_loss = float('inf') # lora: 8G * 8 # full: 18G * 8 - for step, batch in enumerate(dataloader): + for batch in dataloader: # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() consumed_train_samples += BATCH_SIZE - cur_step = model.optimizer_group[ADAPTER_NAME].cur_step + cur_step = optimizer_group.cur_step if cur_step % LOG_INTERVAL == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: - metrics = eval(model) + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = cur_step - if loss_metric > float(metrics['loss']): - model.save( - f'checkpoint-{cur_step}', - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - ) - loss_metric = float(metrics['loss']) - model.save( - 'last-checkpoint', - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - ) + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', consumed_train_samples) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', consumed_train_samples) if __name__ == '__main__': diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py index 3d2d197b..fdacf075 100644 --- a/cookbook/transformers/resume_utils.py +++ b/cookbook/transformers/resume_utils.py @@ -1,11 +1,18 @@ from pathlib import Path from typing import Any, Optional + from twinkle import get_logger logger = get_logger() +def _build_model_kwargs(adapter_name: str) -> dict: + if not adapter_name: + return {} + return {'adapter_name': adapter_name} + + def resume_from_checkpoint( model: Any, dataloader: Any, @@ -14,12 +21,11 @@ def resume_from_checkpoint( resume_only_model: bool, ignore_data_skip: bool, adapter_name: Optional[str] = None) -> int: - checkpoint_dir = str(checkpoint_path) adapter_name = adapter_name or '' - model_kwargs = {} - if adapter_name != '': + checkpoint_dir = str(checkpoint_path) + model_kwargs = _build_model_kwargs(adapter_name) + if model_kwargs: # Load adapter checkpoint. - model_kwargs['adapter_name'] = adapter_name model.load( name=checkpoint_path.name, output_dir=str(checkpoint_path.parent), @@ -33,17 +39,17 @@ def resume_from_checkpoint( return 0 progress = model.read_training_progress(checkpoint_dir, **model_kwargs) # Skip consumed samples in dataloader and move optimizer to the right step. - dataloader.skip_consumed_samples(progress['consumed_train_samples']) - model.optimizer_group[adapter_name].cur_step = progress['cur_step'] - model.optimizer_group[adapter_name].gradient_accumulation_steps = progress['gradient_accumulation_steps'] - consumed_train_samples = int(progress['consumed_train_samples']) + dataloader.skip_consumed_samples(consumed_train_samples) + optimizer_group = model.optimizer_group[adapter_name] + optimizer_group.cur_step = progress['cur_step'] + optimizer_group.gradient_accumulation_steps = progress['gradient_accumulation_steps'] logger.info(f'Skipped {consumed_train_samples} consumed samples.') return consumed_train_samples # Load full training state, including model weights, optimizer states, and training progress. trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) consumed_train_samples = int(trainer_state['consumed_train_samples']) + dataloader.skip_consumed_samples(consumed_train_samples) logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.') return consumed_train_samples From a222b5b169b67bbad270105ecf8f37fa0e631190 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 10:56:27 +0800 Subject: [PATCH 24/60] fix --- src/twinkle/dataloader/dataloader.py | 15 +++++++++++++++ src/twinkle/dataloader/retry_sampler.py | 12 +++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index e2ef57ce..0725ee47 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +import os import warnings from functools import partial from typing import Callable, Optional, Type, Union @@ -56,6 +57,7 @@ def __init__(self, self._skip_samples = 0 self._base_batch_sampler = None self._base_sampler = None + self._retry_sampler_seed = self._resolve_retry_sampler_seed() self._set_work_init_fn() def _set_work_init_fn(self): @@ -65,6 +67,17 @@ def _set_work_init_fn(self): num_workers=num_workers, rank=self.device_mesh.data_rank if self.device_mesh else 0) + @staticmethod + def _resolve_retry_sampler_seed() -> int: + env_seed = os.environ.get('TWINKLE_SEED') + if env_seed is not None: + return int(env_seed) + try: + from twinkle.infra import _seed + return int(_seed) + except Exception: + return 42 + @remote_function() def __len__(self): self._lazy_init_dataloader() @@ -145,6 +158,7 @@ def _rebuild_sampler_stack(self): self._base_sampler, self.dataset, max_retries=self.max_retries, + seed=self._retry_sampler_seed, ) self.dataloader.batch_sampler = DeviceMeshSampler( batch_sampler, @@ -158,4 +172,5 @@ def _rebuild_sampler_stack(self): self.dataset, max_retries=self.max_retries, skip_samples=self._skip_samples, + seed=self._retry_sampler_seed, ) diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 43307b1a..27ef3819 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -14,11 +14,17 @@ class RetrySampler(Sampler): max_retries: The maximum number of retries. """ - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20, skip_samples: int = 0): + def __init__(self, + original_sampler: Sampler, + dataset: Dataset, + max_retries=20, + skip_samples: int = 0, + seed: int = 42): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries self.skip_samples = skip_samples + self.seed = int(seed) def __iter__(self): emitted = 0 @@ -48,9 +54,9 @@ def __iter__(self): if emitted >= target_total: return - for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): + for idx in np.random.RandomState(self.seed).permutation(len(self.dataset)).tolist(): if emitted >= target_total: - raise StopIteration + return for _ in range(self.max_retries): try: # Skip None values and raises From 7499e00f500507b9aacd59432aaa36037abf0576 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 11:09:50 +0800 Subject: [PATCH 25/60] wip --- src/twinkle/dataloader/retry_sampler.py | 2 +- src/twinkle/model/transformers/transformers.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 27ef3819..4d8c92e0 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -49,7 +49,7 @@ def __iter__(self): traceback.print_exc() continue else: - raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') + raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.') if emitted >= target_total: return diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 4e641217..060990ae 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -996,11 +996,11 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): if hasattr(self.strategy, 'load_optimizer_checkpoint'): self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) else: - state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False) + state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True) optimizer_config.optimizer.load_state_dict(state_dict) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: - state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=False) + state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=True) optimizer_config.lr_scheduler.load_state_dict(state_dict) def _load_scaler_state(self, scaler_path, **kwargs): @@ -1009,7 +1009,7 @@ def _load_scaler_state(self, scaler_path, **kwargs): if optimizer_config.scaler is None: raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}') - scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=False) + scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=True) optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict']) optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False) From cd0b09411ed3cf902ca4662b30a204d783709516 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 14:40:17 +0800 Subject: [PATCH 26/60] doc --- .../Components/Model/TransformersModel.md | 13 ++++ docs/source_en/Usage Guide/Quick-Start.md | 67 +++++++++++++++++++ .../Server and Client/Twinkle-Client.md | 49 +++++++++----- ...53\351\200\237\345\274\200\345\247\213.md" | 66 ++++++++++++++++++ ...le\345\256\242\346\210\267\347\253\257.md" | 49 +++++++++----- .../TransformersModel.md" | 13 ++++ 6 files changed, 221 insertions(+), 36 deletions(-) diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index f10b0351..ba4ba1b4 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## Checkpoint and Resume + +`TransformersModel.save()` can save either weights only or a resumable training checkpoint. + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. +- `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights. +- `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`. +- `model.load_training_state(checkpoint_dir, ...)` restores optimizer-related state and returns the training progress dictionary. + +For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `load_training_state(...)` to restore optimizer state and training progress. + +For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 6a05a53f..d8b72ace 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -230,6 +230,71 @@ When running, you need to launch training like this: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### Resume from Checkpoint + +The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py` together with `cookbook/transformers/resume_utils.py`. + +When saving a checkpoint intended for resumption, save both model weights and training progress: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. + +To resume training, restore the checkpoint before entering the main loop: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +This helper provides two common resume modes: + +- Full resume: restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. +- Weights-only resume: restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. + +When `RESUME_ONLY_MODEL=True`, `IGNORE_DATA_SKIP=False` still skips already consumed samples based on `trainer_state.json`. If you want to reload weights but restart the dataset from the beginning, set `IGNORE_DATA_SKIP=True`. + +The flow above is intended for LoRA / adapter training. For full-parameter training, restore model weights by passing the checkpoint path as `model_id` when constructing `TransformersModel`, instead of calling `model.load(...)`. For example: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +In other words: + +- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.load(...)` or `resume_from_checkpoint(...)`. +- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `load_training_state(...)` to restore optimizer state and training progress. + ### Ray Training [Ray](https://github.com/ray-project/ray) is a commonly used scheduling middleware framework for multi-machine model training and inference scenarios. It provides additional optimizations for multi-model, multi-device execution and resource management, and supports integration with Kubernetes systems for production deployment. These characteristics make it particularly suitable for complex training scenarios such as RL and GKD. @@ -413,6 +478,8 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. +Checkpoint resumption is also supported in client-server training. The recommended flow is to restore weights with `model.load(resume_path)`, then restore optimizer and progress metadata with `model.load_training_state(resume_path)`, and finally call `dataloader.skip_consumed_samples(...)`. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. + Suppose we start a service using eight GPUs. First, we need to start the Ray cluster: ```shell diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 66d98eec..e373e30c 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: Training loop -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # Forward propagation + backward propagation - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # Gradient clipping - model.clip_grad_norm(1.0) - - # Optimizer update - model.step() - - # Zero gradients - model.zero_grad() - - # Learning rate scheduling - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: Save checkpoint -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved checkpoint: {twinkle_path}") # Step 8: Upload to ModelScope Hub (optional) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +For checkpoint resumption, the recommended client-side flow is: + +1. Query the server for an existing checkpoint path with `client.list_checkpoints(...)` or `client.get_latest_checkpoint_path(...)`. +2. Call `model.load(resume_path)` to restore adapter weights. +3. Call `model.load_training_state(resume_path)` to restore optimizer, scheduler, RNG, and progress metadata. +4. Call `dataloader.skip_consumed_samples(...)` with `consumed_train_samples` from the returned trainer state. + +This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_congnition.py`. + ## Differences with Megatron Backend When using the Megatron backend, the main differences in client code: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 0b8e386a..a7d98732 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -231,6 +231,71 @@ if __name__ == '__main__': CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### 断点续训 + +上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 + +如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: + +```python +consumed_train_samples = 0 + +def save_checkpoint(model, checkpoint_name): + model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name='default', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, + ) +``` + +其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 `trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。 + +恢复训练时,建议在进入主训练循环之前先加载 checkpoint: + +```python +from pathlib import Path + +from resume_utils import resume_from_checkpoint + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +consumed_train_samples = 0 +if RESUME_FROM_CHECKPOINT: + consumed_train_samples = resume_from_checkpoint( + model=model, + dataloader=dataloader, + checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), + resume_only_model=RESUME_ONLY_MODEL, + ignore_data_skip=IGNORE_DATA_SKIP, + adapter_name='default', + ) +``` + +这个辅助函数覆盖了两种常见恢复模式: + +- 完整续训:恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 +- 仅恢复权重:只加载模型权重,不恢复优化器等训练状态。适合希望沿用参数初始化、但重新开始优化过程的场景。 + +当 `RESUME_ONLY_MODEL=True` 且 `IGNORE_DATA_SKIP=False` 时,仍会根据 `trainer_state.json` 跳过已训练过的数据;如果你只想加载权重、但从数据集开头重新训练,可以把 `IGNORE_DATA_SKIP=True`。 + +上面的恢复流程默认针对 LoRA / adapter 训练。对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 `model_id` 设为 checkpoint 路径,而不是再调用 `model.load(...)`。例如: + +```python +resume_path = './output/fsdp2/last-checkpoint' +model = TransformersModel(model_id=resume_path) +trainer_state = model.load_training_state(resume_path) +dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +``` + +也就是说: + +- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.load(...)` 或 `resume_from_checkpoint(...)` 恢复。 +- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + ### Ray训练 [Ray](https://github.com/ray-project/ray)是多机模型训练和推理场景中常用的调度中间件框架。它针对多模型、多设备的执行和资源管理进行了额外优化, @@ -412,6 +477,7 @@ python train.py ``` ### 远程训练 +client-server 训练场景同样支持断点续训。推荐流程是先通过 `model.load(resume_path)` 恢复权重,再通过 `model.load_training_state(resume_path)` 恢复优化器和训练进度元数据,最后调用 `dataloader.skip_consumed_samples(...)` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 Twinkle的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行lora训练,这样可以极大减小服务端部署成本。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index fd81ac1b..5d4bafe7 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -122,32 +122,36 @@ model.set_optimizer('AdamW', lr=1e-4) model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) +consumed_train_samples = 0 +global_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) # Step 6: 训练循环 -for step, batch in enumerate(dataloader): +for _, batch in enumerate(dataloader): # 前向传播 + 反向传播 - output = model.forward_backward(inputs=batch) + model.forward_backward(inputs=batch) - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') + # Step + model.clip_grad_and_step() + consumed_train_samples += len(batch) + global_step += 1 - # 梯度裁剪 - model.clip_grad_norm(1.0) - - # 优化器更新 - model.step() - - # 梯度清零 - model.zero_grad() - - # 学习率调度 - model.lr_step() + if global_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: 保存检查点 -twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) +twinkle_path = model.save( + name=f'step-{global_step}', + save_optimizer=True, + consumed_train_samples=consumed_train_samples, +) logger.info(f"Saved checkpoint: {twinkle_path}") # Step 8: 上传到 ModelScope Hub(可选) @@ -158,6 +162,15 @@ model.upload_to_hub( ) ``` +Twinkle Client 场景下,推荐的断点续训流程是: + +1. 先通过 `client.list_checkpoints(...)` 或 `client.get_latest_checkpoint_path(...)` 获取已有 checkpoint 路径。 +2. 调用 `model.load(resume_path)` 恢复 adapter 权重。 +3. 调用 `model.load_training_state(resume_path)` 恢复优化器、调度器、随机数状态和训练进度元数据。 +4. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.skip_consumed_samples(...)`,跳过已经训练过的数据。 + +完整示例可直接参考 `cookbook/client/twinkle/self_host/self_congnition.py`。 + ## Megatron 后端的差异 使用 Megatron 后端时,客户端代码的主要差异: diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index b7b9cf0f..bb494131 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -50,3 +50,16 @@ for data in dataloader: model.forward_backward(...) model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## 检查点保存与续训 + +`TransformersModel.save()` 既可以只保存权重,也可以保存可续训的训练检查点。 + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 会在保存权重的同时,保存优化器、学习率调度器、梯度缩放器、随机数状态以及 `trainer_state.json`。 +- `model.load(name, output_dir=..., adapter_name=...)` 用于恢复 LoRA / adapter 模型权重。 +- `model.read_training_progress(checkpoint_dir, ...)` 用于读取 checkpoint 中的训练进度元数据,例如 `cur_step`、`gradient_accumulation_steps` 和 `consumed_train_samples`。 +- `model.load_training_state(checkpoint_dir, ...)` 用于恢复优化器等训练状态,并返回训练进度字典。 + +对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 + +如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 From abf2c2f6b5f1dd027dc0772b883d4c430cee4370 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 17:08:07 +0800 Subject: [PATCH 27/60] wip --- .../model/transformers/transformers.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 060990ae..5bc2cb9f 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -43,6 +43,7 @@ from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm from twinkle.utils.logger import get_logger +from twinkle.utils.platforms import Platform logger = get_logger() @@ -1019,12 +1020,12 @@ def _get_training_rng_state(self): 'numpy_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if hasattr(torch, 'npu') and torch.npu.is_available(): - state['device_type'] = 'npu' - state['device_rng_state'] = torch.npu.get_rng_state() - elif torch.cuda.is_available(): - state['device_type'] = 'cuda' - state['device_rng_state'] = torch.cuda.get_rng_state_all() + + device_prefix = Platform.device_prefix() + device_module = getattr(torch, device_prefix, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + state['device_type'] = device_prefix + state['device_rng_state'] = device_module.get_rng_state() else: state['device_type'] = 'cpu' state['device_rng_state'] = None @@ -1038,10 +1039,10 @@ def _load_rng_state(self, rng_path): device_type = rng_state.get('device_type') device_rng_state = rng_state.get('device_rng_state') - if device_type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available() and device_rng_state is not None: - torch.npu.set_rng_state(device_rng_state) - elif device_type == 'cuda' and torch.cuda.is_available() and device_rng_state is not None: - torch.cuda.set_rng_state_all(device_rng_state) + if device_type != 'cpu' and device_rng_state is not None: + device_module = getattr(torch, device_type, None) + if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): + device_module.set_rng_state(device_rng_state) @remote_function() def read_training_progress(self, checkpoint_dir, **kwargs): From 8bf7a6ad0e975326771664942e280e19a01f5833 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 1 Apr 2026 17:31:12 +0800 Subject: [PATCH 28/60] lint --- src/twinkle/model/transformers/transformers.py | 1 - tests/dataloader/test_dataloader.py | 14 +++++++------- tests/dataloader/test_sampler.py | 10 +++++----- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5bc2cb9f..e2ee5243 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -43,7 +43,6 @@ from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm from twinkle.utils.logger import get_logger -from twinkle.utils.platforms import Platform logger = get_logger() diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 82a4f41b..edad0dd3 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -7,6 +7,13 @@ from torch.utils.data import Dataset as TorchDataset from torch.utils.data import IterableDataset as TorchIterableDataset +import twinkle +from twinkle import DeviceMesh +from twinkle.data_format import Message +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta, IterableDataset +from twinkle.processor import InputProcessor + class _NoOpProcessPoolExecutor: @@ -19,13 +26,6 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -import twinkle -from twinkle import DeviceMesh -from twinkle.data_format import Message -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta, IterableDataset -from twinkle.processor import InputProcessor - twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index d5c97dbc..d90c8725 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -7,6 +7,11 @@ from torch.utils.data import Dataset as TorchDataset from torch.utils.data import RandomSampler, SequentialSampler +import twinkle +from twinkle import DeviceMesh +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta + class _NoOpProcessPoolExecutor: @@ -19,11 +24,6 @@ def submit(self, fn, *args, **kwargs): concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -import twinkle -from twinkle import DeviceMesh -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta - twinkle.initialize(mode='local') TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' From 9326e642604d4430184e6e4321acc50b47401287 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 16 Apr 2026 11:02:33 +0800 Subject: [PATCH 29/60] wip --- .../Server and Client/Twinkle-Client.md | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 900525a0..56afb872 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -148,25 +148,30 @@ logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for _, batch in enumerate(dataloader): + for step, batch in enumerate(dataloader): # Forward propagation + backward propagation model.forward_backward(inputs=batch) - # Step + # Gradient clipping + optimizer update (equivalent to clip_grad_norm / step / zero_grad / lr_step) model.clip_grad_and_step() - consumed_train_samples += len(batch) - global_step += 1 - if global_step % 2 == 0: - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') + if step % 2 == 0: + logger.info(f'Step {step // 2}, loss: {output}') + + # Gradient clipping + model.clip_grad_norm(1.0) + + # Optimizer update + model.step() + + # Zero gradients + model.zero_grad() + + # Learning rate scheduling + model.lr_step() # Step 7: Save checkpoint - twinkle_path = model.save( - name=f'twinkle-epoch-{epoch}', - save_optimizer=True, - consumed_train_samples=consumed_train_samples, - ) + twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 8: Upload to ModelScope Hub (optional) From 670f0c1b04ec9625339a869be369aaeb9752ccd6 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:42:22 +0800 Subject: [PATCH 30/60] feat: add resume_from_checkpoint abstract method to TwinkleModel base --- src/twinkle/model/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 596f3c32..bc132527 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -87,6 +87,10 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs) -> None: def get_state_dict(self, **kwargs) -> Dict[str, Any]: ... + @abstractmethod + def resume_from_checkpoint(self, checkpoint_dir: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: + ... + @abstractmethod def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None: ... From 784730cc2c2dd083b279e3050414a2f333de9428 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:44:09 +0800 Subject: [PATCH 31/60] feat(dataloader): add resume_from_checkpoint wrapping skip_consumed_samples --- src/twinkle/dataloader/dataloader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index 0725ee47..268fad24 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -151,6 +151,10 @@ def skip_consumed_samples(self, consumed_train_samples: int) -> None: self._rebuild_sampler_stack() self.dataloader.__initialized = True + @remote_function() + def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + self.skip_consumed_samples(consumed_train_samples) + def _rebuild_sampler_stack(self): if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): batch_sampler = copy.copy(self._base_batch_sampler) From 3db38e9bf576951de9b53de4295d24864ded631b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:44:22 +0800 Subject: [PATCH 32/60] feat(transformers): replace load_training_state/read_training_progress with resume_from_checkpoint --- .../model/transformers/transformers.py | 58 ++++++++----------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 854e1329..fac596e3 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1060,44 +1060,34 @@ def _load_rng_state(self, rng_path): device_module.set_rng_state(device_rng_state) @remote_function() - def read_training_progress(self, checkpoint_dir, **kwargs): - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - if not os.path.exists(trainer_state_path): - raise FileNotFoundError(trainer_state_path) - - with open(trainer_state_path, encoding='utf-8') as f: - trainer_state = json.load(f) + def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', '') - required_keys = {'checkpoint_version', 'cur_step', 'gradient_accumulation_steps', 'consumed_train_samples'} - missing_keys = required_keys - trainer_state.keys() - if missing_keys: - raise ValueError(f'Missing trainer_state keys: {sorted(missing_keys)}') - return trainer_state + has_adapter = ( + os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) + or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) + ) + if has_adapter: + self.load(checkpoint_dir, adapter_name=adapter_name) - @remote_function() - def load_training_state(self, checkpoint_dir, **kwargs): - adapter_name = kwargs.pop('adapter_name', _default_adapter_name) - optimizer_config = self.optimizer_group[adapter_name] + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) - required_paths = { - 'trainer_state': os.path.join(checkpoint_dir, 'trainer_state.json'), - 'optimizer': os.path.join(checkpoint_dir, 'optimizer.pt'), - 'rng': os.path.join(checkpoint_dir, 'rng_state.pt'), + if not resume_only_model: + adapter_name = adapter_name or self._get_default_group() + optimizer_config = self.optimizer_group[adapter_name] + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._load_scaler_state(checkpoint_dir) + self._load_rng_state(checkpoint_dir) + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], } - for path in required_paths.values(): - if not os.path.exists(path): - raise FileNotFoundError(path) - - trainer_state = self.read_training_progress(checkpoint_dir) - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name, strict=True) - scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') - if os.path.exists(scaler_path) and optimizer_config.scaler is not None: - self._load_scaler_state(scaler_path, adapter_name=adapter_name) - self._load_rng_state(required_paths['rng']) - - optimizer_config.cur_step = trainer_state['cur_step'] - optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] - return trainer_state @remote_function(collect='first') def get_state_dict(self, **kwargs): From 94679d58e413f4dee17eb7e6396191a89de7b853 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:44:22 +0800 Subject: [PATCH 33/60] feat(megatron): add resume_from_checkpoint and save trainer_state.json --- src/twinkle/model/megatron/megatron.py | 28 ++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index f087d3a6..86263ee4 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -813,6 +813,17 @@ def save(self, optimizer_config=optimizer_config, **kwargs, ) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + with open(state_path, 'w') as f: + json.dump(trainer_state, f, indent=2) # Final synchronization to ensure all ranks complete save. if dist.is_initialized(): @@ -866,6 +877,23 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): if dist.is_initialized(): dist.barrier() + @remote_function(dispatch='all') + def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', self._get_default_group()) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, + adapter_name=adapter_name, **kwargs) + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } + @staticmethod def _get_rng_state() -> 'ShardedObject': from megatron.core import parallel_state as mpu From 832ce87bb97e5cae6d5eb2807e621140293295ac Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:46:11 +0800 Subject: [PATCH 34/60] refactor(cookbook): use model.resume_from_checkpoint API --- .../twinkle/self_host/self_congnition.py | 11 ++-- cookbook/transformers/resume_utils.py | 58 +++++-------------- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 97965e9d..61a9224f 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -102,12 +102,11 @@ def train(): consumed_train_samples = 0 global_step = 0 if resume_path: - logger.info(f'Resuming model weights from {resume_path}') - model.load(resume_path) - trainer_state = model.load_training_state(resume_path) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - consumed_train_samples = int(trainer_state['consumed_train_samples']) - global_step = int(trainer_state['cur_step']) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + consumed_train_samples = int(progress['consumed_train_samples']) + global_step = int(progress['cur_step']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py index fdacf075..fd87d123 100644 --- a/cookbook/transformers/resume_utils.py +++ b/cookbook/transformers/resume_utils.py @@ -3,53 +3,25 @@ from twinkle import get_logger - logger = get_logger() -def _build_model_kwargs(adapter_name: str) -> dict: - if not adapter_name: - return {} - return {'adapter_name': adapter_name} - +def resume_from_checkpoint(model: Any, + dataloader: Any, + checkpoint_path: Path, + *, + resume_only_model: bool, + ignore_data_skip: bool, + adapter_name: Optional[str] = None) -> int: + kwargs = {} + if adapter_name: + kwargs['adapter_name'] = adapter_name -def resume_from_checkpoint( - model: Any, - dataloader: Any, - checkpoint_path: Path, - *, - resume_only_model: bool, - ignore_data_skip: bool, - adapter_name: Optional[str] = None) -> int: - adapter_name = adapter_name or '' - checkpoint_dir = str(checkpoint_path) - model_kwargs = _build_model_kwargs(adapter_name) - if model_kwargs: - # Load adapter checkpoint. - model.load( - name=checkpoint_path.name, - output_dir=str(checkpoint_path.parent), - **model_kwargs, - ) + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=resume_only_model, **kwargs) - if resume_only_model: - # Only load model weights, optionally skip data. - if ignore_data_skip: - logger.info('Resumed weights only and restarted progress from step 0.') - return 0 - progress = model.read_training_progress(checkpoint_dir, **model_kwargs) - # Skip consumed samples in dataloader and move optimizer to the right step. - consumed_train_samples = int(progress['consumed_train_samples']) - dataloader.skip_consumed_samples(consumed_train_samples) - optimizer_group = model.optimizer_group[adapter_name] - optimizer_group.cur_step = progress['cur_step'] - optimizer_group.gradient_accumulation_steps = progress['gradient_accumulation_steps'] - logger.info(f'Skipped {consumed_train_samples} consumed samples.') - return consumed_train_samples + consumed_train_samples = int(progress.get('consumed_train_samples', 0)) + if not ignore_data_skip and consumed_train_samples > 0: + dataloader.resume_from_checkpoint(consumed_train_samples) - # Load full training state, including model weights, optimizer states, and training progress. - trainer_state = model.load_training_state(checkpoint_dir, **model_kwargs) - consumed_train_samples = int(trainer_state['consumed_train_samples']) - dataloader.skip_consumed_samples(consumed_train_samples) - logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.') return consumed_train_samples From e3a3cd6950a098fda2f35c0699e7fe8e86640376 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:46:17 +0800 Subject: [PATCH 35/60] feat(types): replace training state request types with ResumeFromCheckpointRequest --- src/twinkle_client/types/model.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index 0b6d7c08..18dae7b7 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -89,20 +89,11 @@ class Config: extra = 'allow' -class LoadTrainingStateRequest(BaseModel): - adapter_name: str +class ResumeFromCheckpointRequest(BaseModel): + """Request for /resume_from_checkpoint endpoint.""" name: str - - class Config: - extra = 'allow' - - -class ReadTrainingProgressRequest(BaseModel): - adapter_name: str - name: str - - class Config: - extra = 'allow' + adapter_name: str = '' + resume_only_model: bool = False class AddAdapterRequest(BaseModel): @@ -229,7 +220,7 @@ class SaveResponse(BaseModel): class TrainingProgressResponse(BaseModel): - """Response for /read_training_progress endpoint (returns progress metadata).""" + """Response for /resume_from_checkpoint endpoint.""" result: Dict[str, Any] From a3effab534f4443e1253657baf25eaaf539bb77b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:46:18 +0800 Subject: [PATCH 36/60] feat(server): replace training state endpoints with /resume_from_checkpoint --- src/twinkle/server/model/twinkle_handlers.py | 39 +++----------------- 1 file changed, 6 insertions(+), 33 deletions(-) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 250fdc57..86dd83ef 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -349,10 +349,10 @@ async def _task(): await run_task(self.schedule_task_and_wait(_task, task_type='load')) - @app.post('/twinkle/load_training_state', response_model=types.TrainingProgressResponse) - async def load_training_state( + @app.post('/twinkle/resume_from_checkpoint', response_model=types.TrainingProgressResponse) + async def resume_from_checkpoint( request: Request, - body: types.LoadTrainingStateRequest, + body: types.ResumeFromCheckpointRequest, self: ModelManagement = Depends(self_fn), ) -> types.TrainingProgressResponse: token = await self._on_request_start(request) @@ -360,46 +360,19 @@ async def load_training_state( async def _task(): self.assert_resource_exists(adapter_name) - extra_kwargs = body.model_extra or {} checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') resolved = checkpoint_manager.resolve_load_path(body.name) checkpoint_dir = ( Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() if resolved.checkpoint_dir else body.name) - ret = self.model.load_training_state( + ret = self.model.resume_from_checkpoint( checkpoint_dir, + resume_only_model=body.resume_only_model, adapter_name=adapter_name, - **extra_kwargs, - ) - return {'result': ret} - - return await run_task(self.schedule_task_and_wait(_task, task_type='load_training_state')) - - @app.post('/twinkle/read_training_progress', response_model=types.TrainingProgressResponse) - async def read_training_progress( - request: Request, - body: types.ReadTrainingProgressRequest, - self: ModelManagement = Depends(self_fn), - ) -> types.TrainingProgressResponse: - token = await self._on_request_start(request) - adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_resource_exists(adapter_name) - extra_kwargs = body.model_extra or {} - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - resolved = checkpoint_manager.resolve_load_path(body.name) - checkpoint_dir = ( - Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix() - if resolved.checkpoint_dir else body.name) - ret = self.model.read_training_progress( - checkpoint_dir, - adapter_name=adapter_name, - **extra_kwargs, ) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='read_training_progress')) + return await run_task(self.schedule_task_and_wait(_task, task_type='resume')) @app.post('/twinkle/upload_to_hub') async def upload_to_hub( From 383336dc72aed120304a50d822742feea403dacd Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:46:23 +0800 Subject: [PATCH 37/60] feat(client): replace training state methods with resume_from_checkpoint --- client_tools/client_generator.py | 17 ++++------------- .../model/multi_lora_transformers.py | 17 ++++------------- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index c9f43fab..99d6fca1 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -618,20 +618,11 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: - """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: response = http_post( - url=f'{self.server_url}/load_training_state', - json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} - ) - response.raise_for_status() - return TrainingProgressResponse(**response.json()).result - - def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: - """Read progress-only checkpoint metadata for resume-only-model flows.""" - response = http_post( - url=f'{self.server_url}/read_training_progress', - json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + url=f'{self.server_url}/resume_from_checkpoint', + json_data={'name': name, 'adapter_name': self.adapter_name, + 'resume_only_model': resume_only_model, **kwargs} ) response.raise_for_status() return TrainingProgressResponse(**response.json()).result diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 37eac765..f7618f7d 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -189,20 +189,11 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() - def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]: - """Load optimizer, scheduler, scaler, RNG, and progress metadata from a checkpoint.""" + def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: response = http_post( - url=f'{self.server_url}/load_training_state', - json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} - ) - response.raise_for_status() - return TrainingProgressResponse(**response.json()).result - - def read_training_progress(self, name: str, **kwargs) -> Dict[str, Any]: - """Read progress-only checkpoint metadata for resume-only-model flows.""" - response = http_post( - url=f'{self.server_url}/read_training_progress', - json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} + url=f'{self.server_url}/resume_from_checkpoint', + json_data={'name': name, 'adapter_name': self.adapter_name, + 'resume_only_model': resume_only_model, **kwargs} ) response.raise_for_status() return TrainingProgressResponse(**response.json()).result From 54a1db6fe846fa3e402b62e088f2e8820722d3d8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:48:45 +0800 Subject: [PATCH 38/60] docs: update checkpoint/resume documentation for unified API --- .../Components/Model/TransformersModel.md | 8 +- docs/source_en/Usage Guide/Quick-Start.md | 8 +- .../Server and Client/Twinkle-Client.md | 16 +- ...53\351\200\237\345\274\200\345\247\213.md" | 8 +- ...le\345\256\242\346\210\267\347\253\257.md" | 16 +- .../TransformersModel.md" | 10 +- .../plans/2026-04-21-unified-resume-api.md | 460 ++++++++++++++++++ .../2026-04-21-unified-resume-api-design.md | 181 +++++++ ...2026-04-21-unified-resume-api-design.zh.md | 181 +++++++ 9 files changed, 853 insertions(+), 35 deletions(-) create mode 100644 docs/superpowers/plans/2026-04-21-unified-resume-api.md create mode 100644 docs/superpowers/specs/2026-04-21-unified-resume-api-design.md create mode 100644 docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index ba4ba1b4..1caab30c 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -56,10 +56,10 @@ for data in dataloader: `TransformersModel.save()` can save either weights only or a resumable training checkpoint. - `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. -- `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights. -- `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`. -- `model.load_training_state(checkpoint_dir, ...)` restores optimizer-related state and returns the training progress dictionary. +- `model.resume_from_checkpoint(checkpoint_dir)` restores full training state (weights, optimizer, scheduler, scaler, RNG) and returns `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`. +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` loads weights only and returns progress metadata without restoring optimizer state. +- `dataloader.resume_from_checkpoint(consumed_train_samples)` skips already-consumed samples. -For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `load_training_state(...)` to restore optimizer state and training progress. +For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 98a4b8e2..28bb8a12 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -286,14 +286,14 @@ The flow above is intended for LoRA / adapter training. For full-parameter train ```python resume_path = './output/fsdp2/last-checkpoint' model = TransformersModel(model_id=resume_path) -trainer_state = model.load_training_state(resume_path) -dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +progress = model.resume_from_checkpoint(resume_path) +dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` In other words: -- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.load(...)` or `resume_from_checkpoint(...)`. -- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `load_training_state(...)` to restore optimizer state and training progress. +- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.resume_from_checkpoint(...)`. +- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. ### Ray Training diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 56afb872..b4a1e61e 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -136,12 +136,11 @@ model.set_optimizer('Adam', lr=1e-4) consumed_train_samples = 0 global_step = 0 if resume_path: - logger.info(f'Resuming model weights from {resume_path}') - model.load(resume_path) - trainer_state = model.load_training_state(resume_path) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - consumed_train_samples = int(trainer_state['consumed_train_samples']) - global_step = int(trainer_state['cur_step']) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + consumed_train_samples = int(progress['consumed_train_samples']) + global_step = int(progress['cur_step']) # Step 6: Training loop logger.info(model.get_train_configs().model_dump()) @@ -187,9 +186,8 @@ for epoch in range(3): For checkpoint resumption, the recommended client-side flow is: 1. Query the server for an existing checkpoint path with `client.list_checkpoints(...)` or `client.get_latest_checkpoint_path(...)`. -2. Call `model.load(resume_path)` to restore adapter weights. -3. Call `model.load_training_state(resume_path)` to restore optimizer, scheduler, RNG, and progress metadata. -4. Call `dataloader.skip_consumed_samples(...)` with `consumed_train_samples` from the returned trainer state. +2. Call `model.resume_from_checkpoint(resume_path)` to restore weights, optimizer, scheduler, RNG, and progress metadata. +3. Call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip already-consumed samples. This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_congnition.py`. diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index b2d72434..07eccc95 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -287,14 +287,14 @@ if RESUME_FROM_CHECKPOINT: ```python resume_path = './output/fsdp2/last-checkpoint' model = TransformersModel(model_id=resume_path) -trainer_state = model.load_training_state(resume_path) -dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) +progress = model.resume_from_checkpoint(resume_path) +dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` 也就是说: -- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.load(...)` 或 `resume_from_checkpoint(...)` 恢复。 -- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 +- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.resume_from_checkpoint(...)` 恢复。 +- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 ### Ray训练 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index f092e4de..9ee6f0fd 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -136,12 +136,11 @@ model.set_optimizer('Adam', lr=1e-4) consumed_train_samples = 0 global_step = 0 if resume_path: - logger.info(f'Resuming model weights from {resume_path}') - model.load(resume_path) - trainer_state = model.load_training_state(resume_path) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - consumed_train_samples = int(trainer_state['consumed_train_samples']) - global_step = int(trainer_state['cur_step']) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + consumed_train_samples = int(progress['consumed_train_samples']) + global_step = int(progress['cur_step']) # Step 6: 训练循环 logger.info(model.get_train_configs().model_dump()) @@ -177,9 +176,8 @@ for epoch in range(3): Twinkle Client 场景下,推荐的断点续训流程是: 1. 先通过 `client.list_checkpoints(...)` 或 `client.get_latest_checkpoint_path(...)` 获取已有 checkpoint 路径。 -2. 调用 `model.load(resume_path)` 恢复 adapter 权重。 -3. 调用 `model.load_training_state(resume_path)` 恢复优化器、调度器、随机数状态和训练进度元数据。 -4. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.skip_consumed_samples(...)`,跳过已经训练过的数据。 +2. 调用 `model.resume_from_checkpoint(resume_path)` 恢复权重、优化器、调度器、随机数状态和训练进度元数据。 +3. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.resume_from_checkpoint(...)`,跳过已经训练过的数据。 完整示例可直接参考 `cookbook/client/twinkle/self_host/self_congnition.py`。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index bb494131..f3816ade 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -55,11 +55,11 @@ for data in dataloader: `TransformersModel.save()` 既可以只保存权重,也可以保存可续训的训练检查点。 -- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 会在保存权重的同时,保存优化器、学习率调度器、梯度缩放器、随机数状态以及 `trainer_state.json`。 -- `model.load(name, output_dir=..., adapter_name=...)` 用于恢复 LoRA / adapter 模型权重。 -- `model.read_training_progress(checkpoint_dir, ...)` 用于读取 checkpoint 中的训练进度元数据,例如 `cur_step`、`gradient_accumulation_steps` 和 `consumed_train_samples`。 -- `model.load_training_state(checkpoint_dir, ...)` 用于恢复优化器等训练状态,并返回训练进度字典。 +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 保存权重、优化器、调度器、scaler、RNG 状态和 `trainer_state.json`。 +- `model.resume_from_checkpoint(checkpoint_dir)` 恢复完整训练状态(权重、优化器、调度器、scaler、RNG),返回 `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`。 +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` 仅加载权重并返回进度元数据,不恢复优化器状态。 +- `dataloader.resume_from_checkpoint(consumed_train_samples)` 跳过已消费的样本。 -对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `load_training_state(...)` 恢复优化器和训练进度。 +对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 diff --git a/docs/superpowers/plans/2026-04-21-unified-resume-api.md b/docs/superpowers/plans/2026-04-21-unified-resume-api.md new file mode 100644 index 00000000..5e0f9a5d --- /dev/null +++ b/docs/superpowers/plans/2026-04-21-unified-resume-api.md @@ -0,0 +1,460 @@ +# Unified `resume_from_checkpoint` API — Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace `load_training_state` / `read_training_progress` with a single `resume_from_checkpoint` method on both model backends and dataloader, so callers orchestrate with two lines instead of five. + +**Architecture:** Add `resume_from_checkpoint` as an abstract method on `TwinkleModel`. Each backend (Transformers, Megatron) implements it to restore its own state internally and return a common `{cur_step, consumed_train_samples, gradient_accumulation_steps}` dict. DataLoader gets a matching `resume_from_checkpoint` that wraps `skip_consumed_samples`. Server/client/cookbook/docs updated to match. + +**Tech Stack:** Python, PyTorch, FastAPI, Pydantic, PEFT, Megatron-Core + +**Spec:** `docs/superpowers/specs/2026-04-21-unified-resume-api-design.md` + +--- + +## Chunk 1: Core Model API + +### Task 1: Add `resume_from_checkpoint` to TwinkleModel base class + +**Files:** +- Modify: `src/twinkle/model/base.py:86-88` + +- [ ] **Step 1: Add abstract method after `get_state_dict`** + +In `src/twinkle/model/base.py`, insert after line 88 (`get_state_dict`): + +```python +@abstractmethod +def resume_from_checkpoint(self, checkpoint_dir: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: + ... +``` + +- [ ] **Step 2: Verify no import changes needed** + +`Dict` and `Any` are already imported on line 4. No changes needed. + +- [ ] **Step 3: Commit** + +```bash +git add src/twinkle/model/base.py +git commit -m "feat: add resume_from_checkpoint abstract method to TwinkleModel base" +``` + +--- + +### Task 2: Implement `resume_from_checkpoint` in TransformersModel + +**Files:** +- Modify: `src/twinkle/model/transformers/transformers.py:1063-1100` + +- [ ] **Step 1: Delete `read_training_progress` method (lines 1063-1075)** + +Remove the entire `read_training_progress` method. + +- [ ] **Step 2: Delete `load_training_state` method (lines 1078-1100)** + +Remove the entire `load_training_state` method. + +- [ ] **Step 3: Add `resume_from_checkpoint` method** + +Insert at the same location where the deleted methods were: + +```python +@remote_function() +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', '') + + has_adapter = ( + os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) + or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) + ) + if has_adapter: + self.load(checkpoint_dir, adapter_name=adapter_name) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + if not resume_only_model: + adapter_name = adapter_name or self._get_default_group() + optimizer_config = self.optimizer_group[adapter_name] + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._load_scaler_state(checkpoint_dir) + self._load_rng_state(checkpoint_dir) + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +- [ ] **Step 4: Verify `json` and `os` imports exist** + +`json` is imported at line 4, `os` at line 6. No changes needed. + +- [ ] **Step 5: Commit** + +```bash +git add src/twinkle/model/transformers/transformers.py +git commit -m "feat(transformers): replace load_training_state/read_training_progress with resume_from_checkpoint" +``` + +--- + +### Task 3: Implement `resume_from_checkpoint` in MegatronModel + update `save` + +**Files:** +- Modify: `src/twinkle/model/megatron/megatron.py:762-821` (save), add new method after `load` + +- [ ] **Step 1: Update `save()` to write `trainer_state.json`** + +In `src/twinkle/model/megatron/megatron.py`, find the `if save_optimizer:` block (around line 810). After the `_save_mcore_optimizer` call and before the barrier, add: + +```python + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + with open(state_path, 'w') as f: + json.dump(trainer_state, f, indent=2) +``` + +- [ ] **Step 2: Add `resume_from_checkpoint` method** + +Insert after the `load` method (after line 867): + +```python +@remote_function(dispatch='all') +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', self._get_default_group()) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, + adapter_name=adapter_name, **kwargs) + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +- [ ] **Step 3: Verify `json` import exists** + +`json` is imported at line 3. No changes needed. + +- [ ] **Step 4: Commit** + +```bash +git add src/twinkle/model/megatron/megatron.py +git commit -m "feat(megatron): add resume_from_checkpoint and save trainer_state.json" +``` + +--- + +### Task 4: Add `resume_from_checkpoint` to DataLoader + +**Files:** +- Modify: `src/twinkle/dataloader/dataloader.py` (after `skip_consumed_samples`, around line 152) + +- [ ] **Step 1: Add method after `skip_consumed_samples`** + +```python +@remote_function() +def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + self.skip_consumed_samples(consumed_train_samples) +``` + +- [ ] **Step 2: Commit** + +```bash +git add src/twinkle/dataloader/dataloader.py +git commit -m "feat(dataloader): add resume_from_checkpoint wrapping skip_consumed_samples" +``` + +--- + +## Chunk 2: Server, Client, Types + +### Task 5: Update Pydantic types + +**Files:** +- Modify: `src/twinkle_client/types/model.py:92-105` (request types), `231-233` (response type) + +- [ ] **Step 1: Delete `LoadTrainingStateRequest` (lines 92-97) and `ReadTrainingProgressRequest` (lines 100-105)** + +Remove both request classes. + +- [ ] **Step 2: Add `ResumeFromCheckpointRequest`** + +Insert at the same location: + +```python +class ResumeFromCheckpointRequest(BaseModel): + """Request for /resume_from_checkpoint endpoint.""" + name: str + adapter_name: str = '' + resume_only_model: bool = False +``` + +- [ ] **Step 3: Rename `TrainingProgressResponse` docstring (line 232)** + +Update the docstring from `"Response for /read_training_progress endpoint"` to `"Response for /resume_from_checkpoint endpoint"`. Keep the class name and `result` field unchanged. + +- [ ] **Step 4: Commit** + +```bash +git add src/twinkle_client/types/model.py +git commit -m "feat(types): replace training state request types with ResumeFromCheckpointRequest" +``` + +--- + +### Task 6: Update server endpoints + +**Files:** +- Modify: `src/twinkle/server/model/twinkle_handlers.py:352-402` + +- [ ] **Step 1: Delete `load_training_state` endpoint (lines 352-376)** + +Remove the entire endpoint function. + +- [ ] **Step 2: Delete `read_training_progress` endpoint (lines 378-402)** + +Remove the entire endpoint function. + +- [ ] **Step 3: Add `resume_from_checkpoint` endpoint** + +Insert at the same location, following the existing endpoint pattern: + +```python +@app.post('/twinkle/resume_from_checkpoint', response_model=types.TrainingProgressResponse) +async def resume_from_checkpoint( + request: Request, + body: types.ResumeFromCheckpointRequest, + self: ModelManagement = Depends(self_fn), +): + token = await self._on_request_start(request) + + async def _task(): + checkpoint_dir = self._resolve_checkpoint_dir(body.name) + result = self.model.resume_from_checkpoint( + checkpoint_dir, + resume_only_model=body.resume_only_model, + adapter_name=body.adapter_name or token, + ) + return types.TrainingProgressResponse(result=result) + + return await run_task(self.schedule_task_and_wait(_task, task_type='resume')) +``` + +Note: Check how `load_training_state` resolves `checkpoint_dir` from `body.name` — replicate the same pattern. If there's a `_resolve_checkpoint_dir` helper, use it. Otherwise inline the resolution logic (typically `os.path.join(output_dir, name)` or direct path). + +- [ ] **Step 4: Commit** + +```bash +git add src/twinkle/server/model/twinkle_handlers.py +git commit -m "feat(server): replace training state endpoints with /resume_from_checkpoint" +``` + +--- + +### Task 7: Update client SDK + +**Files:** +- Modify: `src/twinkle_client/model/multi_lora_transformers.py:192-208` +- Modify: `client_tools/client_generator.py:621-637` + +- [ ] **Step 1: Update `src/twinkle_client/model/multi_lora_transformers.py`** + +Delete `load_training_state` (lines 192-199) and `read_training_progress` (lines 201-208). Replace with: + +```python +def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: + response = http_post( + url=f'{self.server_url}/resume_from_checkpoint', + json_data={'name': name, 'adapter_name': self.adapter_name, + 'resume_only_model': resume_only_model, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result +``` + +- [ ] **Step 2: Update `client_tools/client_generator.py`** + +Delete `load_training_state` (lines 621-628) and `read_training_progress` (lines 630-637). Replace with the same `resume_from_checkpoint` method as above. + +- [ ] **Step 3: Commit** + +```bash +git add src/twinkle_client/model/multi_lora_transformers.py client_tools/client_generator.py +git commit -m "feat(client): replace training state methods with resume_from_checkpoint" +``` + +--- + +## Chunk 3: Cookbook and Documentation + +### Task 8: Update cookbook examples + +**Files:** +- Modify: `cookbook/transformers/resume_utils.py:16-55` +- Modify: `cookbook/client/twinkle/self_host/self_congnition.py:102-110` + +- [ ] **Step 1: Rewrite `resume_from_checkpoint` in `cookbook/transformers/resume_utils.py`** + +The old helper function manually orchestrated model + dataloader state. Replace the function body (lines 16-55) with a simplified version that delegates to the new model API: + +```python +def resume_from_checkpoint(model, dataloader, checkpoint_path, *, resume_only_model=False, + ignore_data_skip=False, adapter_name=None) -> int: + kwargs = {} + if adapter_name: + kwargs['adapter_name'] = adapter_name + + progress = model.resume_from_checkpoint( + checkpoint_path, resume_only_model=resume_only_model, **kwargs) + + consumed_train_samples = int(progress.get('consumed_train_samples', 0)) + if not ignore_data_skip and consumed_train_samples > 0: + dataloader.resume_from_checkpoint(consumed_train_samples) + + return consumed_train_samples +``` + +This keeps the helper for backward compatibility with existing cookbook scripts that call it, but the implementation now delegates to the model's own method. + +- [ ] **Step 2: Update `cookbook/client/twinkle/self_host/self_congnition.py`** + +Replace the resume block (around lines 102-110): + +```python +# Before: +consumed_train_samples = 0 +global_step = 0 +if resume_path: + logger.info(f'Resuming model weights from {resume_path}') + model.load(resume_path) + trainer_state = model.load_training_state(resume_path) + dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) + consumed_train_samples = int(trainer_state['consumed_train_samples']) + global_step = int(trainer_state['cur_step']) +``` + +With: + +```python +consumed_train_samples = 0 +global_step = 0 +if resume_path: + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + consumed_train_samples = int(progress['consumed_train_samples']) + global_step = int(progress['cur_step']) +``` + +- [ ] **Step 3: Commit** + +```bash +git add cookbook/transformers/resume_utils.py cookbook/client/twinkle/self_host/self_congnition.py +git commit -m "refactor(cookbook): use model.resume_from_checkpoint API" +``` + +--- + +### Task 9: Update documentation + +**Files:** +- Modify: `docs/source_en/Components/Model/TransformersModel.md:54-65` +- Modify: `docs/source_zh/组件/模型/TransformersModel.md:54-65` +- Modify: `docs/source_en/Usage Guide/Quick-Start.md:289-296` +- Modify: `docs/source_zh/使用指引/快速开始.md:290-297` +- Modify: `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md:141,191` +- Modify: `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md:141,181` + +- [ ] **Step 1: Update English TransformersModel.md (lines 54-65)** + +Replace the checkpoint section with: + +```markdown +### Checkpoint and Resume + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. +- `model.resume_from_checkpoint(checkpoint_dir)` restores full training state (weights, optimizer, scheduler, scaler, RNG) and returns `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`. +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` loads weights only and returns progress metadata without restoring optimizer state. +- `dataloader.resume_from_checkpoint(consumed_train_samples)` skips already-consumed samples. +``` + +- [ ] **Step 2: Update Chinese TransformersModel.md (lines 54-65)** + +Mirror the English changes in Chinese: + +```markdown +### 检查点保存与续训 + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 保存权重、优化器、调度器、scaler、RNG 状态和 `trainer_state.json`。 +- `model.resume_from_checkpoint(checkpoint_dir)` 恢复完整训练状态(权重、优化器、调度器、scaler、RNG),返回 `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`。 +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` 仅加载权重并返回进度元数据,不恢复优化器状态。 +- `dataloader.resume_from_checkpoint(consumed_train_samples)` 跳过已消费的样本。 +``` + +- [ ] **Step 3: Update Quick-Start docs (EN and ZH)** + +In both `docs/source_en/Usage Guide/Quick-Start.md` and `docs/source_zh/使用指引/快速开始.md`, replace `model.load_training_state(resume_path)` references with: + +```python +progress = model.resume_from_checkpoint(resume_path) +dataloader.resume_from_checkpoint(progress['consumed_train_samples']) +``` + +Update the explanatory text accordingly. + +- [ ] **Step 4: Update Twinkle-Client docs (EN and ZH)** + +In both `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md`, replace `model.load_training_state(resume_path)` references with `model.resume_from_checkpoint(resume_path)`. + +- [ ] **Step 5: Commit** + +```bash +git add docs/ +git commit -m "docs: update checkpoint/resume documentation for unified API" +``` + +--- + +### Task 10: Final grep verification + +- [ ] **Step 1: Verify no stale references remain** + +```bash +grep -rn "load_training_state\|read_training_progress" src/ cookbook/ client_tools/ docs/ --include="*.py" --include="*.md" +``` + +Expected: Only hits in `docs/superpowers/` (our spec/plan files). No hits in source code, cookbook, or user-facing docs. + +- [ ] **Step 2: Run pre-commit hooks** + +```bash +pre-commit run --all-files +``` + +Fix any formatting issues (isort, yapf, flake8). + +- [ ] **Step 3: Final commit if needed** + +```bash +git add -A +git commit -m "chore: fix formatting after resume API refactor" +``` diff --git a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md new file mode 100644 index 00000000..a9f62b20 --- /dev/null +++ b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md @@ -0,0 +1,181 @@ +# Unified `resume_from_checkpoint` API Design + +## Problem + +The current checkpoint resume API on the `resume_from_ckpt` branch exposes two similar methods (`load_training_state` and `read_training_progress`) that are hard to distinguish. The caller must manually orchestrate state restoration across model and dataloader, acting as a data courier between components. Additionally, the Megatron backend lacks these methods entirely, creating an asymmetric API surface. + +## Design Principle + +Each component is responsible for its own state restoration. The caller only orchestrates — it does not transport data between components. + +## Target API + +```python +progress = model.resume_from_checkpoint(checkpoint_path) +dataloader.resume_from_checkpoint(progress['consumed_train_samples']) +``` + +Two lines. Both backends. No `resume_utils.py` helper needed. + +## Return Value Contract + +`model.resume_from_checkpoint()` returns a dict with exactly these keys: + +```python +{ + 'cur_step': int, # optimizer step count + 'consumed_train_samples': int, # total samples consumed + 'gradient_accumulation_steps': int, # GAS value at save time +} +``` + +Backend-specific state (optimizer tensors, scaler, RNG, mcore sharded state) is restored internally and not exposed. + +## Component Changes + +### 1. TwinkleModel Base Class (`src/twinkle/model/base.py`) + +Add abstract method: + +```python +@abstractmethod +def resume_from_checkpoint( + self, + checkpoint_dir: str, + *, + resume_only_model: bool = False, + **kwargs, +) -> Dict[str, Any]: + ... +``` + +Parameters: +- `checkpoint_dir`: Path to the checkpoint directory. +- `resume_only_model`: If True, load weights only — skip optimizer/scheduler/RNG restoration. Useful for fine-tuning with a different optimizer config. +- `**kwargs`: Backend-specific args (e.g., `adapter_name`). + +### 2. TransformersModel (`src/twinkle/model/transformers/transformers.py`) + +Delete public methods: `load_training_state()`, `read_training_progress()`. + +Retain private helpers: `_save_training_state()`, `_load_optimizer()`, `_load_scaler_state()`, `_load_rng_state()`, `_get_training_rng_state()`. + +New implementation: + +```python +@remote_function() +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', '') + + # Load adapter weights if checkpoint contains adapter files. + has_adapter = ( + os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) + or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) + ) + if has_adapter: + self.load(checkpoint_dir, adapter_name=adapter_name) + + # Read trainer_state.json. + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + # Full restore: optimizer, scheduler, scaler, RNG. + if not resume_only_model: + optimizer_group = self._get_optimizer_group(adapter_name) + self._load_optimizer(checkpoint_dir, optimizer_group, adapter_name) + self._load_scaler_state(checkpoint_dir) + self._load_rng_state(checkpoint_dir) + optimizer_group.cur_step = trainer_state['cur_step'] + optimizer_group.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +Full-parameter training: weights are loaded at model initialization time, so `has_adapter` is False and `self.load()` is skipped. Only training state is restored. + +### 3. MegatronModel (`src/twinkle/model/megatron/megatron.py`) + +**save() change:** When `save_optimizer=True`, also write `trainer_state.json`: + +```python +if save_optimizer: + self._save_mcore_optimizer(checkpoint_dir, optimizer_config=optimizer_config, **kwargs) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if self.device_mesh.rank == 0: + with open(state_path, 'w') as f: + json.dump(trainer_state, f, indent=2) +``` + +**New resume_from_checkpoint():** + +```python +@remote_function(dispatch='all') +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', self._get_default_group()) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, + adapter_name=adapter_name, **kwargs) + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +Megatron's `load(load_optimizer=True)` already restores optimizer/scheduler/RNG/cur_step via `_load_mcore_optimizer`. The `resume_from_checkpoint` wrapper adds `trainer_state.json` reading for `consumed_train_samples`. + +### 4. DataLoader (`src/twinkle/dataloader/dataloader.py`) + +New method: + +```python +def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + self.skip_consumed_samples(consumed_train_samples) +``` + +`skip_consumed_samples` is retained as-is (not renamed) for backward compatibility. `resume_from_checkpoint` is the recommended public API going forward. + +### 5. Server Endpoints (`src/twinkle/server/model/twinkle_handlers.py`) + +- Delete: `/twinkle/load_training_state`, `/twinkle/read_training_progress` +- Add: `/twinkle/resume_from_checkpoint` accepting `checkpoint_dir` and `resume_only_model` parameters + +### 6. Client SDK (`src/twinkle_client/`, `client_tools/client_generator.py`) + +- Delete: `load_training_state()`, `read_training_progress()` client methods +- Add: `resume_from_checkpoint()` client method + +### 7. Cookbook Changes + +- Delete `resume_from_checkpoint()` helper from `cookbook/transformers/resume_utils.py` (functionality now lives in the model) +- Update all cookbook examples to use the new two-line API + +### 8. Documentation + +Update `docs/source_en/Components/Model/TransformersModel.md` and corresponding Chinese docs to reflect the new API. + +## Migration Summary + +| Before | After | +|--------|-------| +| `model.load(path)` | `progress = model.resume_from_checkpoint(path)` | +| `model.load_training_state(path)` | (merged into above) | +| `model.read_training_progress(path)` | `progress = model.resume_from_checkpoint(path, resume_only_model=True)` | +| `dataloader.skip_consumed_samples(n)` | `dataloader.resume_from_checkpoint(n)` | +| `resume_from_checkpoint(model, dataloader, ...)` (cookbook util) | Two-line inline call | diff --git a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md new file mode 100644 index 00000000..3323170d --- /dev/null +++ b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md @@ -0,0 +1,181 @@ +# 统一的 `resume_from_checkpoint` API 设计 + +## 问题 + +当前在 `resume_from_ckpt` 分支上的断点续训 API 暴露了两个相似的方法(`load_training_state` 和 `read_training_progress`),难以区分。调用方必须手动编排模型和数据加载器之间的状态恢复,充当组件之间的数据搬运工。此外,Megatron 后端完全没有这些方法,导致 API 表面不对称。 + +## 设计原则 + +每个组件负责自身的状态恢复。调用方只负责编排 —— 不在组件之间搬运数据。 + +## 目标 API + +```python +progress = model.resume_from_checkpoint(checkpoint_path) +dataloader.resume_from_checkpoint(progress['consumed_train_samples']) +``` + +两行代码。两个后端。不再需要 `resume_utils.py` 辅助工具。 + +## 返回值约定 + +`model.resume_from_checkpoint()` 返回一个 dict,包含以下确切的键: + +```python +{ + 'cur_step': int, # 优化器步数 + 'consumed_train_samples': int, # 已消耗的总样本数 + 'gradient_accumulation_steps': int, # 保存时的 GAS 值 +} +``` + +后端特定的状态(优化器张量、scaler、RNG、mcore 分片状态)在内部恢复,不对外暴露。 + +## 组件变更 + +### 1. TwinkleModel 基类 (`src/twinkle/model/base.py`) + +添加抽象方法: + +```python +@abstractmethod +def resume_from_checkpoint( + self, + checkpoint_dir: str, + *, + resume_only_model: bool = False, + **kwargs, +) -> Dict[str, Any]: + ... +``` + +参数说明: +- `checkpoint_dir`: 检查点目录的路径。 +- `resume_only_model`: 如果为 True,则仅加载权重 —— 跳过优化器/调度器/RNG 的恢复。适用于使用不同优化器配置进行微调的场景。 +- `**kwargs`: 后端特定的参数(例如 `adapter_name`)。 + +### 2. TransformersModel (`src/twinkle/model/transformers/transformers.py`) + +删除公共方法:`load_training_state()`、`read_training_progress()`。 + +保留私有辅助方法:`_save_training_state()`、`_load_optimizer()`、`_load_scaler_state()`、`_load_rng_state()`、`_get_training_rng_state()`。 + +新实现: + +```python +@remote_function() +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', '') + + # 如果检查点包含适配器文件,则加载适配器权重。 + has_adapter = ( + os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) + or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) + ) + if has_adapter: + self.load(checkpoint_dir, adapter_name=adapter_name) + + # 读取 trainer_state.json。 + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + # 完整恢复:优化器、调度器、scaler、RNG。 + if not resume_only_model: + optimizer_group = self._get_optimizer_group(adapter_name) + self._load_optimizer(checkpoint_dir, optimizer_group, adapter_name) + self._load_scaler_state(checkpoint_dir) + self._load_rng_state(checkpoint_dir) + optimizer_group.cur_step = trainer_state['cur_step'] + optimizer_group.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +全参数训练:权重在模型初始化时加载,因此 `has_adapter` 为 False,`self.load()` 被跳过。仅恢复训练状态。 + +### 3. MegatronModel (`src/twinkle/model/megatron/megatron.py`) + +**save() 变更:** 当 `save_optimizer=True` 时,同时写入 `trainer_state.json`: + +```python +if save_optimizer: + self._save_mcore_optimizer(checkpoint_dir, optimizer_config=optimizer_config, **kwargs) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + if self.device_mesh.rank == 0: + with open(state_path, 'w') as f: + json.dump(trainer_state, f, indent=2) +``` + +**新的 resume_from_checkpoint():** + +```python +@remote_function(dispatch='all') +def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.get('adapter_name', self._get_default_group()) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, + adapter_name=adapter_name, **kwargs) + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } +``` + +Megatron 的 `load(load_optimizer=True)` 已经通过 `_load_mcore_optimizer` 恢复了优化器/调度器/RNG/cur_step。`resume_from_checkpoint` 包装器增加了 `trainer_state.json` 的读取,以获取 `consumed_train_samples`。 + +### 4. 数据加载器 (`src/twinkle/dataloader/dataloader.py`) + +新方法: + +```python +def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + self.skip_consumed_samples(consumed_train_samples) +``` + +`skip_consumed_samples` 保留原样(不更名)以保持向后兼容。`resume_from_checkpoint` 是今后推荐的公共 API。 + +### 5. 服务端接口 (`src/twinkle/server/model/twinkle_handlers.py`) + +- 删除:`/twinkle/load_training_state`、`/twinkle/read_training_progress` +- 新增:`/twinkle/resume_from_checkpoint`,接受 `checkpoint_dir` 和 `resume_only_model` 参数 + +### 6. 客户端 SDK (`src/twinkle_client/`、`client_tools/client_generator.py`) + +- 删除:`load_training_state()`、`read_training_progress()` 客户端方法 +- 新增:`resume_from_checkpoint()` 客户端方法 + +### 7. Cookbook 变更 + +- 删除 `cookbook/transformers/resume_utils.py` 中的 `resume_from_checkpoint()` 辅助函数(功能现已内置于模型中) +- 更新所有 cookbook 示例以使用新的两行 API + +### 8. 文档 + +更新 `docs/source_en/Components/Model/TransformersModel.md` 及对应的中文文档,以反映新的 API。 + +## 迁移摘要 + +| 之前 | 之后 | +|--------|-------| +| `model.load(path)` | `progress = model.resume_from_checkpoint(path)` | +| `model.load_training_state(path)` | (合并到上方) | +| `model.read_training_progress(path)` | `progress = model.resume_from_checkpoint(path, resume_only_model=True)` | +| `dataloader.skip_consumed_samples(n)` | `dataloader.resume_from_checkpoint(n)` | +| `resume_from_checkpoint(model, dataloader, ...)` (cookbook 工具函数) | 两行内联调用 | From 597cbd9c7daca139fb0312fcf4fbf7906828e9cb Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:50:42 +0800 Subject: [PATCH 39/60] fix: remove stale load_training_state references from __init__.py, multi_lora, and docs --- docs/source_en/Usage Guide/Quick-Start.md | 2 +- .../\345\277\253\351\200\237\345\274\200\345\247\213.md" | 2 +- src/twinkle/model/transformers/multi_lora_transformers.py | 2 +- src/twinkle_client/types/__init__.py | 3 +-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 28bb8a12..46ded616 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -478,7 +478,7 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. -Checkpoint resumption is also supported in client-server training. The recommended flow is to restore weights with `model.load(resume_path)`, then restore optimizer and progress metadata with `model.load_training_state(resume_path)`, and finally call `dataloader.skip_consumed_samples(...)`. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. +Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. Suppose we start a service using eight GPUs. First, we need to start the Ray cluster: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 07eccc95..0e075bb8 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -477,7 +477,7 @@ python train.py ``` ### 远程训练 -client-server 训练场景同样支持断点续训。推荐流程是先通过 `model.load(resume_path)` 恢复权重,再通过 `model.load_training_state(resume_path)` 恢复优化器和训练进度元数据,最后调用 `dataloader.skip_consumed_samples(...)` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 +client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。 diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 6b9fcb2c..e8a1ca43 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -251,7 +251,7 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self.load_training_state(checkpoint_dir, adapter_name=adapter_name) + self.resume_from_checkpoint(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 0e5d37e1..380a2d9d 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -23,13 +23,12 @@ GetStateDictRequest, GetStateDictResponse, GetTrainConfigsResponse, - LoadTrainingStateRequest, LoadRequest, LoadResponse, LrStepResponse, ModelResult, OkResponse, - ReadTrainingProgressRequest, + ResumeFromCheckpointRequest, SaveRequest, SaveResponse, SetLossRequest, From c55ab9f15e7f9f20a14dd4f939dc24f74cceaaa1 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 11:57:33 +0800 Subject: [PATCH 40/60] fix(transformers): pass correct file paths to _load_scaler_state and _load_rng_state --- src/twinkle/model/transformers/transformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index fac596e3..a91a1c89 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1078,8 +1078,10 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k adapter_name = adapter_name or self._get_default_group() optimizer_config = self.optimizer_group[adapter_name] self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) - self._load_scaler_state(checkpoint_dir) - self._load_rng_state(checkpoint_dir) + scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') + if os.path.exists(scaler_path) and optimizer_config.scaler is not None: + self._load_scaler_state(scaler_path, adapter_name=adapter_name) + self._load_rng_state(os.path.join(checkpoint_dir, 'rng_state.pt')) optimizer_config.cur_step = trainer_state['cur_step'] optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] From 8f76b7bd96a8697067e78f49fb99a01923a0647a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 12:17:37 +0800 Subject: [PATCH 41/60] fix: guard rng_state.pt existence check, add Config extra=allow to ResumeFromCheckpointRequest --- src/twinkle/model/transformers/transformers.py | 4 +++- src/twinkle_client/types/model.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a91a1c89..4271b0a5 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1081,7 +1081,9 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') if os.path.exists(scaler_path) and optimizer_config.scaler is not None: self._load_scaler_state(scaler_path, adapter_name=adapter_name) - self._load_rng_state(os.path.join(checkpoint_dir, 'rng_state.pt')) + rng_path = os.path.join(checkpoint_dir, 'rng_state.pt') + if os.path.exists(rng_path): + self._load_rng_state(rng_path) optimizer_config.cur_step = trainer_state['cur_step'] optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index 18dae7b7..a0afa945 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -95,6 +95,9 @@ class ResumeFromCheckpointRequest(BaseModel): adapter_name: str = '' resume_only_model: bool = False + class Config: + extra = 'allow' + class AddAdapterRequest(BaseModel): adapter_name: str From 4ffa5c7858210ef04acc78cea073f49873c2a08f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 16:25:56 +0800 Subject: [PATCH 42/60] wip --- .../transformers/multi_lora_transformers.py | 10 +++-- .../model/transformers/transformers.py | 38 +++++++++++-------- src/twinkle_client/dataloader/dataloader.py | 14 +++++++ 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index e8a1ca43..72f139a6 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -239,9 +239,11 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k with self.multi_adapter.save_context(kwargs.get('adapter_name')): load_optimizer = kwargs.get('load_optimizer', False) if output_dir is None: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) + if os.path.exists(name): + checkpoint_dir = name + else: + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) else: checkpoint_dir = os.path.join(output_dir, name) model = self.strategy.unwrap_model(self.model) @@ -251,7 +253,7 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self.resume_from_checkpoint(checkpoint_dir, adapter_name=adapter_name) + self._restore_training_state(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 4271b0a5..f01044af 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1059,6 +1059,25 @@ def _load_rng_state(self, rng_path): if device_module and hasattr(device_module, 'is_available') and device_module.is_available(): device_module.set_rng_state(device_rng_state) + def _restore_training_state(self, checkpoint_dir, *, adapter_name=''): + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path, 'r') as f: + trainer_state = json.load(f) + + adapter_name = adapter_name or self._get_default_group() + optimizer_config = self.optimizer_group[adapter_name] + self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') + if os.path.exists(scaler_path) and optimizer_config.scaler is not None: + self._load_scaler_state(scaler_path, adapter_name=adapter_name) + rng_path = os.path.join(checkpoint_dir, 'rng_state.pt') + if os.path.exists(rng_path): + self._load_rng_state(rng_path) + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return trainer_state + @remote_function() def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): adapter_name = kwargs.get('adapter_name', '') @@ -1070,22 +1089,11 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k if has_adapter: self.load(checkpoint_dir, adapter_name=adapter_name) - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - if not resume_only_model: - adapter_name = adapter_name or self._get_default_group() - optimizer_config = self.optimizer_group[adapter_name] - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) - scaler_path = os.path.join(checkpoint_dir, 'scaler.pt') - if os.path.exists(scaler_path) and optimizer_config.scaler is not None: - self._load_scaler_state(scaler_path, adapter_name=adapter_name) - rng_path = os.path.join(checkpoint_dir, 'rng_state.pt') - if os.path.exists(rng_path): - self._load_rng_state(rng_path) - optimizer_config.cur_step = trainer_state['cur_step'] - optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + trainer_state = self._restore_training_state(checkpoint_dir, adapter_name=adapter_name) + else: + with open(os.path.join(checkpoint_dir, 'trainer_state.json')) as f: + trainer_state = json.load(f) return { 'cur_step': trainer_state['cur_step'], diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index f6a24fe4..d7d08bb4 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -95,4 +95,18 @@ def skip_consumed_samples(self, consumed_train_samples: int): ) response.raise_for_status() return response.json()["result"] + + + def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'resume_from_checkpoint', + **{'consumed_train_samples': consumed_train_samples}, + **kwargs + } + ) + response.raise_for_status() + return response.json()["result"] \ No newline at end of file From 0b43055f670761c2978793373220a202f4a58dfa Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 16:34:04 +0800 Subject: [PATCH 43/60] wip --- .../transformers/strategy/sequence_parallel.py | 11 +++++++++++ src/twinkle/model/transformers/transformers.py | 13 +++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py index 64ea34f3..dc7cbeae 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py @@ -1039,3 +1039,14 @@ def wrap_model(self, model, optimizer=None): def unwrap_model(self, model): return model + + def needs_wrapped_optimizer_state(self) -> bool: + return False + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + from twinkle.utils.platforms import Platform + if Platform.is_master(): + torch.save(optimizer.state_dict(), output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index f01044af..5912e0a1 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -892,10 +892,7 @@ def _save_optimizer(self, output_dir, **kwargs): if optimizer is not None: optimizer_path = os.path.join(output_dir, 'optimizer.pt') - if hasattr(self.strategy, 'save_optimizer_checkpoint'): - self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path) - elif Platform.is_master(): - torch.save(optimizer.state_dict(), optimizer_path) + self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path) if Platform.is_master(): if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) @@ -1007,13 +1004,9 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.', ) if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None: - if getattr(self.strategy, 'needs_wrapped_optimizer_state', lambda: False)() and not self._model_wrapped: + if self.strategy.needs_wrapped_optimizer_state() and not self._model_wrapped: self._lazy_wrap_model() - if hasattr(self.strategy, 'load_optimizer_checkpoint'): - self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) - else: - state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True) - optimizer_config.optimizer.load_state_dict(state_dict) + self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path) if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None: state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=True) From c8bc9ab3d7903b56e1dd897996e0198bd0516a01 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 16:42:02 +0800 Subject: [PATCH 44/60] wip --- src/twinkle/model/transformers/strategy/accelerate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index b9b81f59..d1e4984d 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -1,6 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Any, Dict, Literal, Optional +import torch + from twinkle import DeviceMesh from .load_context import fsdp_pretrained_load_context @@ -147,9 +149,9 @@ def needs_wrapped_optimizer_state(self) -> bool: return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2 def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + import torch fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: - import torch from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options()) @@ -157,14 +159,13 @@ def save_optimizer_checkpoint(self, model, optimizer, output_path: str): torch.save(optim_state, output_path) return - import torch if self.accelerator.process_index == 0: torch.save(optimizer.state_dict(), output_path) def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + import torch fsdp_plugin = self._get_fsdp_plugin() if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: - import torch from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict optim_state = None @@ -174,7 +175,6 @@ def load_optimizer_checkpoint(self, model, optimizer, input_path: str): set_optimizer_state_dict(model, optimizer, optim_state, options=self._prepare_fsdp2_sd_options()) return - import torch optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) def get_full_state_dict(self, model) -> dict: From 8c0399e4c8d6f98d5340436dbfcff2c4d73e5e46 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 16:46:55 +0800 Subject: [PATCH 45/60] wip --- src/twinkle/model/base.py | 6 +++++- src/twinkle/model/megatron/megatron.py | 5 ++--- src/twinkle/model/transformers/strategy/accelerate.py | 2 -- src/twinkle/model/transformers/transformers.py | 8 ++------ 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index bc132527..7a5f9e5c 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -88,7 +88,11 @@ def get_state_dict(self, **kwargs) -> Dict[str, Any]: ... @abstractmethod - def resume_from_checkpoint(self, checkpoint_dir: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: + def resume_from_checkpoint(self, + checkpoint_dir: str, + *, + resume_only_model: bool = False, + **kwargs) -> Dict[str, Any]: ... @abstractmethod diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 86263ee4..14020b47 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -882,11 +882,10 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k adapter_name = kwargs.get('adapter_name', self._get_default_group()) trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: + with open(trainer_state_path) as f: trainer_state = json.load(f) - self.load(checkpoint_dir, load_optimizer=not resume_only_model, - adapter_name=adapter_name, **kwargs) + self.load(checkpoint_dir, load_optimizer=not resume_only_model, adapter_name=adapter_name, **kwargs) return { 'cur_step': trainer_state['cur_step'], diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d1e4984d..8d31291a 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -1,8 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Any, Dict, Literal, Optional -import torch - from twinkle import DeviceMesh from .load_context import fsdp_pretrained_load_context diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5912e0a1..92360132 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -42,9 +42,6 @@ from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm -from twinkle.utils.logger import get_logger - -logger = get_logger() logger = get_logger() @@ -1054,7 +1051,7 @@ def _load_rng_state(self, rng_path): def _restore_training_state(self, checkpoint_dir, *, adapter_name=''): trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: + with open(trainer_state_path) as f: trainer_state = json.load(f) adapter_name = adapter_name or self._get_default_group() @@ -1077,8 +1074,7 @@ def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **k has_adapter = ( os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) - or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) - ) + or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin'))) if has_adapter: self.load(checkpoint_dir, adapter_name=adapter_name) From 10b4a20008ce78994091aaabb538e824bcd81de1 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 21 Apr 2026 17:09:07 +0800 Subject: [PATCH 46/60] refactor: delete resume_utils.py, inline logic in fsdp2.py, update docs --- cookbook/transformers/fsdp2.py | 18 ++++++------- cookbook/transformers/resume_utils.py | 27 ------------------- .../Components/Model/TransformersModel.md | 2 +- docs/source_en/Usage Guide/Quick-Start.md | 23 +++++++--------- ...53\351\200\237\345\274\200\345\247\213.md" | 23 +++++++--------- .../TransformersModel.md" | 2 +- 6 files changed, 28 insertions(+), 67 deletions(-) delete mode 100644 cookbook/transformers/resume_utils.py diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 45dd8ac1..5a61ff0a 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -10,8 +10,6 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -from resume_utils import resume_from_checkpoint - logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen3.5-4B' @@ -88,14 +86,14 @@ def train(): consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - consumed_train_samples = resume_from_checkpoint( - model=model, - dataloader=dataloader, - checkpoint_path=checkpoint_path, - resume_only_model=RESUME_ONLY_MODEL, - ignore_data_skip=IGNORE_DATA_SKIP, - adapter_name=ADAPTER_NAME, - ) + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + consumed_train_samples = int(progress.get('consumed_train_samples', 0)) + if not IGNORE_DATA_SKIP and consumed_train_samples > 0: + dataloader.resume_from_checkpoint(consumed_train_samples) logger.info(get_device_placement()) # Print the training config diff --git a/cookbook/transformers/resume_utils.py b/cookbook/transformers/resume_utils.py deleted file mode 100644 index fd87d123..00000000 --- a/cookbook/transformers/resume_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -from pathlib import Path -from typing import Any, Optional - -from twinkle import get_logger - -logger = get_logger() - - -def resume_from_checkpoint(model: Any, - dataloader: Any, - checkpoint_path: Path, - *, - resume_only_model: bool, - ignore_data_skip: bool, - adapter_name: Optional[str] = None) -> int: - kwargs = {} - if adapter_name: - kwargs['adapter_name'] = adapter_name - - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=resume_only_model, **kwargs) - - consumed_train_samples = int(progress.get('consumed_train_samples', 0)) - if not ignore_data_skip and consumed_train_samples > 0: - dataloader.resume_from_checkpoint(consumed_train_samples) - - return consumed_train_samples diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index 1caab30c..5a071ce4 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -62,4 +62,4 @@ for data in dataloader: For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. -For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. +For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 46ded616..eae40356 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -232,7 +232,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ### Resume from Checkpoint -The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py` together with `cookbook/transformers/resume_utils.py`. +The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py`. When saving a checkpoint intended for resumption, save both model weights and training progress: @@ -256,28 +256,23 @@ To resume training, restore the checkpoint before entering the main loop: ```python from pathlib import Path -from resume_utils import resume_from_checkpoint - RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' RESUME_ONLY_MODEL = False IGNORE_DATA_SKIP = False consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: - consumed_train_samples = resume_from_checkpoint( - model=model, - dataloader=dataloader, - checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), - resume_only_model=RESUME_ONLY_MODEL, - ignore_data_skip=IGNORE_DATA_SKIP, - adapter_name='default', - ) + checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) + progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL) + consumed_train_samples = int(progress.get('consumed_train_samples', 0)) + if not IGNORE_DATA_SKIP and consumed_train_samples > 0: + dataloader.resume_from_checkpoint(consumed_train_samples) ``` -This helper provides two common resume modes: +This covers two common resume modes: -- Full resume: restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. -- Weights-only resume: restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. +- Full resume (default): restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. +- Weights-only resume (`resume_only_model=True`): restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. When `RESUME_ONLY_MODEL=True`, `IGNORE_DATA_SKIP=False` still skips already consumed samples based on `trainer_state.json`. If you want to reload weights but restart the dataset from the beginning, set `IGNORE_DATA_SKIP=True`. diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 0e075bb8..e4b9022b 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -233,7 +233,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ### 断点续训 -上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 +上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py`。 如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: @@ -257,28 +257,23 @@ def save_checkpoint(model, checkpoint_name): ```python from pathlib import Path -from resume_utils import resume_from_checkpoint - RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' RESUME_ONLY_MODEL = False IGNORE_DATA_SKIP = False consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: - consumed_train_samples = resume_from_checkpoint( - model=model, - dataloader=dataloader, - checkpoint_path=Path(RESUME_FROM_CHECKPOINT).expanduser().resolve(), - resume_only_model=RESUME_ONLY_MODEL, - ignore_data_skip=IGNORE_DATA_SKIP, - adapter_name='default', - ) + checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) + progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL) + consumed_train_samples = int(progress.get('consumed_train_samples', 0)) + if not IGNORE_DATA_SKIP and consumed_train_samples > 0: + dataloader.resume_from_checkpoint(consumed_train_samples) ``` -这个辅助函数覆盖了两种常见恢复模式: +这两行 API 覆盖了两种常见恢复模式: -- 完整续训:恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 -- 仅恢复权重:只加载模型权重,不恢复优化器等训练状态。适合希望沿用参数初始化、但重新开始优化过程的场景。 +- 完整续训(默认):恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 +- 仅恢复权重(`resume_only_model=True`):只加载模型权重,不恢复优化器等训练状态。适合希望沿用参数初始化、但重新开始优化过程的场景。 当 `RESUME_ONLY_MODEL=True` 且 `IGNORE_DATA_SKIP=False` 时,仍会根据 `trainer_state.json` 跳过已训练过的数据;如果你只想加载权重、但从数据集开头重新训练,可以把 `IGNORE_DATA_SKIP=True`。 diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index f3816ade..5aeead91 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -62,4 +62,4 @@ for data in dataloader: 对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 -如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py` 和 `cookbook/transformers/resume_utils.py`。 +如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py`。 From 3df191a87c301c234223ab8cb994a3c93b925b62 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 23 Apr 2026 09:37:47 +0800 Subject: [PATCH 47/60] wip --- .../twinkle/self_host/self_cognition.py | 13 +- cookbook/megatron/tp_resume.py | 113 ++++++++++++++++++ cookbook/transformers/fsdp2.py | 15 +-- cookbook/transformers/fsdp2_full.py | 112 +++++++++++++++++ .../Components/Model/TransformersModel.md | 1 + docs/source_en/Usage Guide/Quick-Start.md | 14 +-- .../Server and Client/Twinkle-Client.md | 4 - ...53\351\200\237\345\274\200\345\247\213.md" | 14 +-- ...le\345\256\242\346\210\267\347\253\257.md" | 4 - .../TransformersModel.md" | 1 + src/twinkle/dataloader/dataloader.py | 13 +- src/twinkle_client/dataloader/dataloader.py | 13 ++ 12 files changed, 272 insertions(+), 45 deletions(-) create mode 100644 cookbook/megatron/tp_resume.py create mode 100644 cookbook/transformers/fsdp2_full.py diff --git a/cookbook/client/twinkle/self_host/self_cognition.py b/cookbook/client/twinkle/self_host/self_cognition.py index 61a9224f..a8010586 100644 --- a/cookbook/client/twinkle/self_host/self_cognition.py +++ b/cookbook/client/twinkle/self_host/self_cognition.py @@ -99,14 +99,10 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint - consumed_train_samples = 0 - global_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - consumed_train_samples = int(progress['consumed_train_samples']) - global_step = int(progress['cur_step']) # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) @@ -119,8 +115,6 @@ def train(): # Step model.clip_grad_and_step() - consumed_train_samples += len(batch) - global_step += 1 # Equal to the following steps: # # Clip gradients to prevent exploding gradients (max norm = 1.0) # model.clip_grad_norm(1.0) @@ -132,16 +126,17 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - if global_step % 2 == 0: + cur_step = dataloader.get_state()['consumed_train_samples'] // 4 + if cur_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint twinkle_path = model.save( name=f'twinkle-epoch-{epoch}', save_optimizer=True, - consumed_train_samples=consumed_train_samples, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) logger.info(f'Saved checkpoint: {twinkle_path}') diff --git a/cookbook/megatron/tp_resume.py b/cookbook/megatron/tp_resume.py new file mode 100644 index 00000000..a099789e --- /dev/null +++ b/cookbook/megatron/tp_resume.py @@ -0,0 +1,113 @@ +from pathlib import Path + +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +DP_SIZE = 2 +TP_SIZE = 2 +PP_SIZE = 2 +BATCH_SIZE = 16 +LEARNING_RATE = 1e-4 +LOG_INTERVAL = 5 +EVAL_INTERVAL = 20 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +OUTPUT_DIR = './output/megatron_tp' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) + dataset.encode() + return dataset + + +def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): + model.forward_only(inputs=batch) + return model.calculate_metric(is_training=False) + + +def train(): + dataset = build_dataset(TRAIN_SAMPLES) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + model = MegatronModel(model_id=MODEL_ID) + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + + # Add a lora to model, with name `default` + # Comment this to use full-parameter training + model.add_adapter_to_model(ADAPTER_NAME, lora_config) + model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + + best_loss = float('inf') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + if step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + if step > 0 and step % EVAL_INTERVAL == 0: + metrics = evaluate(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 5a61ff0a..450906c5 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -47,13 +47,13 @@ def build_dataset(num_samples: int) -> Dataset: return dataset -def save_checkpoint(model: TransformersModel, checkpoint_name: str, consumed_train_samples: int): +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): model.save( checkpoint_name, output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME, save_optimizer=True, - consumed_train_samples=consumed_train_samples, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) @@ -83,7 +83,6 @@ def train(): model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) - consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() kwargs = {} @@ -91,9 +90,8 @@ def train(): kwargs['adapter_name'] = ADAPTER_NAME progress = model.resume_from_checkpoint( str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - consumed_train_samples = int(progress.get('consumed_train_samples', 0)) - if not IGNORE_DATA_SKIP and consumed_train_samples > 0: - dataloader.resume_from_checkpoint(consumed_train_samples) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) logger.info(get_device_placement()) # Print the training config @@ -108,7 +106,6 @@ def train(): model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - consumed_train_samples += BATCH_SIZE cur_step = optimizer_group.cur_step if cur_step % LOG_INTERVAL == 0: # Print metric @@ -120,9 +117,9 @@ def train(): metrics['step'] = cur_step current_loss = float(metrics['loss']) if current_loss < best_loss: - save_checkpoint(model, f'checkpoint-{cur_step}', consumed_train_samples) + save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) best_loss = current_loss - save_checkpoint(model, 'last-checkpoint', consumed_train_samples) + save_checkpoint(model, 'last-checkpoint', dataloader) if __name__ == '__main__': diff --git a/cookbook/transformers/fsdp2_full.py b/cookbook/transformers/fsdp2_full.py new file mode 100644 index 00000000..5aebf6ac --- /dev/null +++ b/cookbook/transformers/fsdp2_full.py @@ -0,0 +1,112 @@ +from pathlib import Path + +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +FSDP_SIZE = 2 +DP_SIZE = 1 +BATCH_SIZE = 8 +LEARNING_RATE = 1e-5 +WEIGHT_DECAY = 0.01 +GRADIENT_ACCUMULATION_STEPS = 1 +LOG_INTERVAL = 1 +EVAL_INTERVAL = 20 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +import time +OUTPUT_DIR = f'./output/fsdp2_full_{int(time.time())}' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False + +device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) + dataset.encode() + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): + model.forward_only(inputs=batch) + model.calculate_loss() + return model.calculate_metric(is_training=False) + + +def train(): + dataset = build_dataset(TRAIN_SAMPLES) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + model = TransformersModel(model_id=MODEL_ID) + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + + model.set_optimizer( + optimizer_cls='AdamW', + lr=LEARNING_RATE, + weight_decay=WEIGHT_DECAY, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + progress = model.resume_from_checkpoint(str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + + optimizer_group = model.optimizer_group[''] + best_loss = float('inf') + + for batch in dataloader: + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + cur_step = optimizer_group.cur_step + if cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: + metrics = evaluate(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = cur_step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) + + +if __name__ == '__main__': + train() diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index 5a071ce4..a9005a55 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -59,6 +59,7 @@ for data in dataloader: - `model.resume_from_checkpoint(checkpoint_dir)` restores full training state (weights, optimizer, scheduler, scaler, RNG) and returns `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`. - `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` loads weights only and returns progress metadata without restoring optimizer state. - `dataloader.resume_from_checkpoint(consumed_train_samples)` skips already-consumed samples. +- `dataloader.get_state()` returns `{'consumed_train_samples': int}` — the dataloader automatically tracks consumed samples, so you don't need to maintain a counter manually. For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index eae40356..9f4869b5 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -237,19 +237,17 @@ The local and `torchrun` training loops above can be extended to support checkpo When saving a checkpoint intended for resumption, save both model weights and training progress: ```python -consumed_train_samples = 0 - -def save_checkpoint(model, checkpoint_name): +def save_checkpoint(model, checkpoint_name, dataloader): model.save( checkpoint_name, output_dir='./output/fsdp2', adapter_name='default', save_optimizer=True, - consumed_train_samples=consumed_train_samples, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) ``` -`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. +`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. The `DataLoader` automatically tracks consumed samples internally — call `dataloader.get_state()` to retrieve the current count. To resume training, restore the checkpoint before entering the main loop: @@ -260,13 +258,11 @@ RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' RESUME_ONLY_MODEL = False IGNORE_DATA_SKIP = False -consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL) - consumed_train_samples = int(progress.get('consumed_train_samples', 0)) - if not IGNORE_DATA_SKIP and consumed_train_samples > 0: - dataloader.resume_from_checkpoint(consumed_train_samples) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` This covers two common resume modes: diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index b4a1e61e..a4bd971d 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -133,14 +133,10 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) -consumed_train_samples = 0 -global_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - consumed_train_samples = int(progress['consumed_train_samples']) - global_step = int(progress['cur_step']) # Step 6: Training loop logger.info(model.get_train_configs().model_dump()) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index e4b9022b..6dceaa1f 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -238,19 +238,17 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py 如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: ```python -consumed_train_samples = 0 - -def save_checkpoint(model, checkpoint_name): +def save_checkpoint(model, checkpoint_name, dataloader): model.save( checkpoint_name, output_dir='./output/fsdp2', adapter_name='default', save_optimizer=True, - consumed_train_samples=consumed_train_samples, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], ) ``` -其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 `trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。 +其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 `trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。`DataLoader` 会自动追踪已消费的样本数,通过 `dataloader.get_state()` 即可获取当前计数。 恢复训练时,建议在进入主训练循环之前先加载 checkpoint: @@ -261,13 +259,11 @@ RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' RESUME_ONLY_MODEL = False IGNORE_DATA_SKIP = False -consumed_train_samples = 0 if RESUME_FROM_CHECKPOINT: checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL) - consumed_train_samples = int(progress.get('consumed_train_samples', 0)) - if not IGNORE_DATA_SKIP and consumed_train_samples > 0: - dataloader.resume_from_checkpoint(consumed_train_samples) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` 这两行 API 覆盖了两种常见恢复模式: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index 9ee6f0fd..e40ce7e1 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -133,14 +133,10 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) -consumed_train_samples = 0 -global_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - consumed_train_samples = int(progress['consumed_train_samples']) - global_step = int(progress['cur_step']) # Step 6: 训练循环 logger.info(model.get_train_configs().model_dump()) diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index 5aeead91..cd0f16ad 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -59,6 +59,7 @@ for data in dataloader: - `model.resume_from_checkpoint(checkpoint_dir)` 恢复完整训练状态(权重、优化器、调度器、scaler、RNG),返回 `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`。 - `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` 仅加载权重并返回进度元数据,不恢复优化器状态。 - `dataloader.resume_from_checkpoint(consumed_train_samples)` 跳过已消费的样本。 +- `dataloader.get_state()` 返回 `{'consumed_train_samples': int}` — DataLoader 会自动追踪已消费样本数,无需手动维护计数器。 对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index 268fad24..c392d56c 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -55,6 +55,7 @@ def __init__(self, self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None self._skip_samples = 0 + self._consumed_train_samples = 0 self._base_batch_sampler = None self._base_sampler = None self._retry_sampler_seed = self._resolve_retry_sampler_seed() @@ -134,7 +135,12 @@ def __iter__(self): self.batch_size, self.device_mesh, max_retries=self.max_retries) - return _iter + return self._tracking_iter(_iter) + + def _tracking_iter(self, inner): + for batch in inner: + self._consumed_train_samples += self.batch_size + yield batch @remote_function() def skip_consumed_samples(self, consumed_train_samples: int) -> None: @@ -146,6 +152,7 @@ def skip_consumed_samples(self, consumed_train_samples: int) -> None: return self._skip_samples = max(int(consumed_train_samples), 0) + self._consumed_train_samples = self._skip_samples if self.dataloader is not None: self.dataloader.__initialized = False self._rebuild_sampler_stack() @@ -155,6 +162,10 @@ def skip_consumed_samples(self, consumed_train_samples: int) -> None: def resume_from_checkpoint(self, consumed_train_samples, **kwargs): self.skip_consumed_samples(consumed_train_samples) + @remote_function() + def get_state(self) -> dict: + return {'consumed_train_samples': self._consumed_train_samples} + def _rebuild_sampler_stack(self): if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): batch_sampler = copy.copy(self._base_batch_sampler) diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index d7d08bb4..4164940d 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -109,4 +109,17 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs): ) response.raise_for_status() return response.json()["result"] + + + def get_state(self): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'get_state', + **{}, + } + ) + response.raise_for_status() + return response.json()["result"] \ No newline at end of file From 7a657e859fc48de44f381224ea7d2d33f0e76546 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 27 Apr 2026 10:21:59 +0800 Subject: [PATCH 48/60] wip --- src/twinkle/model/megatron/megatron.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 9f5da895..9fa3a46e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -872,7 +872,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): output_dir = 'output' checkpoint_dir = os.path.join(output_dir, name) - adapter_name = kwargs.get('adapter_name', self._get_default_group()) + adapter_name = kwargs.pop('adapter_name', self._get_default_group()) if resume: self._load_mcore_optimizer( @@ -894,7 +894,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): @remote_function(dispatch='all') def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', self._get_default_group()) + adapter_name = kwargs.pop('adapter_name', self._get_default_group()) trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') with open(trainer_state_path) as f: From ae6712235be47b3b2a4360592b95ff338e3dbe07 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 27 Apr 2026 16:28:32 +0800 Subject: [PATCH 49/60] fix --- .../model/megatron/multi_lora_megatron.py | 139 ++++++++++++++++-- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 04c7cba5..4b2acc87 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -1,15 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os +import random import re -import torch.distributed as dist -import torch.nn as nn from contextlib import contextmanager from functools import partial +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn from peft import LoraConfig from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import AutoConfig, PretrainedConfig -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory @@ -203,10 +208,71 @@ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], self._check_adapter_valid(kwargs.get('adapter_name')) super().set_lr_scheduler(scheduler_cls, **kwargs) + @staticmethod + def _rank_local_optimizer_path(checkpoint_dir: str) -> str: + rank = dist.get_rank() if dist.is_initialized() else 0 + return os.path.join(checkpoint_dir, f'optimizer_rank_{rank}.pt') + + @staticmethod + def _get_local_training_rng_state(): + from megatron.core import tensor_parallel + + rng_state = { + 'random_rng_state': random.getstate(), + 'np_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + } + if torch.cuda.is_available(): + rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() + return rng_state + + @staticmethod + def _restore_local_training_rng_state(rng_state): + from megatron.core import tensor_parallel + + random.setstate(rng_state['random_rng_state']) + np.random.set_state(rng_state['np_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): + torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) + + def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): + os.makedirs(checkpoint_dir, exist_ok=True) + state_dict = { + 'checkpoint_version': 1, + 'iteration': optimizer_config.cur_step, + 'rng_state': self._get_local_training_rng_state(), + } + if optimizer_config.optimizer is not None: + state_dict['optimizer'] = optimizer_config.optimizer.state_dict() + if optimizer_config.lr_scheduler is not None: + state_dict['opt_param_scheduler'] = optimizer_config.lr_scheduler.state_dict() + + torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): + no_load_optim = kwargs.pop('no_load_optim', False) + no_load_rng = kwargs.pop('no_load_rng', False) + optimizer_config = self.optimizer_group.get(adapter_name) + state_dict = torch.load(self._rank_local_optimizer_path(checkpoint_dir), map_location='cpu', weights_only=False) + + if not no_load_optim and optimizer_config is not None: + if optimizer_config.optimizer is not None and 'optimizer' in state_dict: + optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: + optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) + if not no_load_rng and 'rng_state' in state_dict: + self._restore_local_training_rng_state(state_dict['rng_state']) + if optimizer_config is not None and 'iteration' in state_dict: + optimizer_config.cur_step = state_dict['iteration'] + @remote_function(dispatch='all', collect='first', sync=True) def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): - self._check_adapter_valid(kwargs.get('adapter_name')) - optimizer_config = self.optimizer_group[kwargs.get('adapter_name')] + adapter_name = kwargs.pop('adapter_name', None) + self._check_adapter_valid(adapter_name) + optimizer_config = self.optimizer_group[adapter_name] if optimizer_config.cur_step % interval != 0: return @@ -215,8 +281,9 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): if output_dir is None: output_dir = 'output' checkpoint_dir = os.path.join(output_dir, name) + save_optimizer = kwargs.pop('save_optimizer', False) - with self.multi_adapter.save_context(kwargs.get('adapter_name')) as real_adapter_name: + with self.multi_adapter.save_context(adapter_name) as real_adapter_name: save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' # Use partial to bind adapter_name to save_lora_converter lora_converter = partial(self.multi_adapter.save_lora_converter, adapter_name=real_adapter_name) @@ -228,7 +295,25 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): else: self._save_megatron_format(checkpoint_dir, real_adapter_name, lora_converter=lora_converter) - self._save_tokenizer(checkpoint_dir, adapter_name=kwargs.get('adapter_name')) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) + if save_optimizer: + with self.optimizer_context(real_adapter_name): + self._save_multi_lora_optimizer( + checkpoint_dir, + optimizer_config=optimizer_config, + **kwargs, + ) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f: + json.dump(trainer_state, f, indent=2) # Final synchronization to ensure all ranks complete save if dist.is_initialized(): dist.barrier() @@ -237,25 +322,55 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): @remote_function(dispatch='all') def load(self, name: str, output_dir: Optional[str] = None, **kwargs): + load_optimizer = kwargs.pop('load_optimizer', False) + adapter_name = kwargs.pop('adapter_name', None) if output_dir is None: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) + if os.path.exists(name): + checkpoint_dir = name + else: + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) else: checkpoint_dir = os.path.join(output_dir, name) bridge = self.strategy.bridge - with self.multi_adapter.save_context(kwargs.get('adapter_name')) as adapter_name: + with self.multi_adapter.save_context(adapter_name) as real_adapter_name: model = self.strategy.unwrap_model(self.model) bridge.load_weights( model, checkpoint_dir, peft_format=True, - adapter_name=adapter_name, + adapter_name=real_adapter_name, converter=self.multi_adapter.load_lora_converter) + if load_optimizer: + with self.optimizer_context(real_adapter_name): + self._load_multi_lora_optimizer(checkpoint_dir, adapter_name=adapter_name, **kwargs) + if dist.is_initialized(): dist.barrier() + @remote_function(dispatch='all', collect='first', sync=True) + def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.pop('adapter_name', None) + self._check_adapter_valid(adapter_name) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path) as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, adapter_name=adapter_name, **kwargs) + + optimizer_config = self.optimizer_group.get(adapter_name) + if not resume_only_model and optimizer_config is not None: + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } + @remote_function(execute='first') def get_state_dict(self, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) From 5b15d6744b63ca0f03bc2e4d98b1513c991028b4 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 27 Apr 2026 16:30:03 +0800 Subject: [PATCH 50/60] lint --- src/twinkle/model/megatron/multi_lora_megatron.py | 9 ++++----- src/twinkle/server/model/twinkle_handlers.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 4b2acc87..4f67a450 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -1,20 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import json +import numpy as np import os import random import re -from contextlib import contextmanager -from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union - -import numpy as np import torch import torch.distributed as dist import torch.nn as nn +from contextlib import contextmanager +from functools import partial from peft import LoraConfig from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import AutoConfig, PretrainedConfig +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index cc1db810..3989fed9 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -8,8 +8,8 @@ """ from __future__ import annotations -import os import asyncio +import os import torch import traceback from fastapi import Depends, FastAPI, HTTPException, Request From f0d36e2c5088baca90a192901fb2339441e50832 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 27 Apr 2026 16:32:06 +0800 Subject: [PATCH 51/60] remove --- .../plans/2026-04-21-unified-resume-api.md | 460 ------------------ .../2026-04-21-unified-resume-api-design.md | 181 ------- ...2026-04-21-unified-resume-api-design.zh.md | 181 ------- 3 files changed, 822 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-21-unified-resume-api.md delete mode 100644 docs/superpowers/specs/2026-04-21-unified-resume-api-design.md delete mode 100644 docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md diff --git a/docs/superpowers/plans/2026-04-21-unified-resume-api.md b/docs/superpowers/plans/2026-04-21-unified-resume-api.md deleted file mode 100644 index 5e0f9a5d..00000000 --- a/docs/superpowers/plans/2026-04-21-unified-resume-api.md +++ /dev/null @@ -1,460 +0,0 @@ -# Unified `resume_from_checkpoint` API — Implementation Plan - -> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Replace `load_training_state` / `read_training_progress` with a single `resume_from_checkpoint` method on both model backends and dataloader, so callers orchestrate with two lines instead of five. - -**Architecture:** Add `resume_from_checkpoint` as an abstract method on `TwinkleModel`. Each backend (Transformers, Megatron) implements it to restore its own state internally and return a common `{cur_step, consumed_train_samples, gradient_accumulation_steps}` dict. DataLoader gets a matching `resume_from_checkpoint` that wraps `skip_consumed_samples`. Server/client/cookbook/docs updated to match. - -**Tech Stack:** Python, PyTorch, FastAPI, Pydantic, PEFT, Megatron-Core - -**Spec:** `docs/superpowers/specs/2026-04-21-unified-resume-api-design.md` - ---- - -## Chunk 1: Core Model API - -### Task 1: Add `resume_from_checkpoint` to TwinkleModel base class - -**Files:** -- Modify: `src/twinkle/model/base.py:86-88` - -- [ ] **Step 1: Add abstract method after `get_state_dict`** - -In `src/twinkle/model/base.py`, insert after line 88 (`get_state_dict`): - -```python -@abstractmethod -def resume_from_checkpoint(self, checkpoint_dir: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: - ... -``` - -- [ ] **Step 2: Verify no import changes needed** - -`Dict` and `Any` are already imported on line 4. No changes needed. - -- [ ] **Step 3: Commit** - -```bash -git add src/twinkle/model/base.py -git commit -m "feat: add resume_from_checkpoint abstract method to TwinkleModel base" -``` - ---- - -### Task 2: Implement `resume_from_checkpoint` in TransformersModel - -**Files:** -- Modify: `src/twinkle/model/transformers/transformers.py:1063-1100` - -- [ ] **Step 1: Delete `read_training_progress` method (lines 1063-1075)** - -Remove the entire `read_training_progress` method. - -- [ ] **Step 2: Delete `load_training_state` method (lines 1078-1100)** - -Remove the entire `load_training_state` method. - -- [ ] **Step 3: Add `resume_from_checkpoint` method** - -Insert at the same location where the deleted methods were: - -```python -@remote_function() -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', '') - - has_adapter = ( - os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) - or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) - ) - if has_adapter: - self.load(checkpoint_dir, adapter_name=adapter_name) - - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - if not resume_only_model: - adapter_name = adapter_name or self._get_default_group() - optimizer_config = self.optimizer_group[adapter_name] - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) - self._load_scaler_state(checkpoint_dir) - self._load_rng_state(checkpoint_dir) - optimizer_config.cur_step = trainer_state['cur_step'] - optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -- [ ] **Step 4: Verify `json` and `os` imports exist** - -`json` is imported at line 4, `os` at line 6. No changes needed. - -- [ ] **Step 5: Commit** - -```bash -git add src/twinkle/model/transformers/transformers.py -git commit -m "feat(transformers): replace load_training_state/read_training_progress with resume_from_checkpoint" -``` - ---- - -### Task 3: Implement `resume_from_checkpoint` in MegatronModel + update `save` - -**Files:** -- Modify: `src/twinkle/model/megatron/megatron.py:762-821` (save), add new method after `load` - -- [ ] **Step 1: Update `save()` to write `trainer_state.json`** - -In `src/twinkle/model/megatron/megatron.py`, find the `if save_optimizer:` block (around line 810). After the `_save_mcore_optimizer` call and before the barrier, add: - -```python - trainer_state = { - 'checkpoint_version': 1, - 'cur_step': optimizer_config.cur_step, - 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), - 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, - } - state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - with open(state_path, 'w') as f: - json.dump(trainer_state, f, indent=2) -``` - -- [ ] **Step 2: Add `resume_from_checkpoint` method** - -Insert after the `load` method (after line 867): - -```python -@remote_function(dispatch='all') -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', self._get_default_group()) - - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - self.load(checkpoint_dir, load_optimizer=not resume_only_model, - adapter_name=adapter_name, **kwargs) - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -- [ ] **Step 3: Verify `json` import exists** - -`json` is imported at line 3. No changes needed. - -- [ ] **Step 4: Commit** - -```bash -git add src/twinkle/model/megatron/megatron.py -git commit -m "feat(megatron): add resume_from_checkpoint and save trainer_state.json" -``` - ---- - -### Task 4: Add `resume_from_checkpoint` to DataLoader - -**Files:** -- Modify: `src/twinkle/dataloader/dataloader.py` (after `skip_consumed_samples`, around line 152) - -- [ ] **Step 1: Add method after `skip_consumed_samples`** - -```python -@remote_function() -def resume_from_checkpoint(self, consumed_train_samples, **kwargs): - self.skip_consumed_samples(consumed_train_samples) -``` - -- [ ] **Step 2: Commit** - -```bash -git add src/twinkle/dataloader/dataloader.py -git commit -m "feat(dataloader): add resume_from_checkpoint wrapping skip_consumed_samples" -``` - ---- - -## Chunk 2: Server, Client, Types - -### Task 5: Update Pydantic types - -**Files:** -- Modify: `src/twinkle_client/types/model.py:92-105` (request types), `231-233` (response type) - -- [ ] **Step 1: Delete `LoadTrainingStateRequest` (lines 92-97) and `ReadTrainingProgressRequest` (lines 100-105)** - -Remove both request classes. - -- [ ] **Step 2: Add `ResumeFromCheckpointRequest`** - -Insert at the same location: - -```python -class ResumeFromCheckpointRequest(BaseModel): - """Request for /resume_from_checkpoint endpoint.""" - name: str - adapter_name: str = '' - resume_only_model: bool = False -``` - -- [ ] **Step 3: Rename `TrainingProgressResponse` docstring (line 232)** - -Update the docstring from `"Response for /read_training_progress endpoint"` to `"Response for /resume_from_checkpoint endpoint"`. Keep the class name and `result` field unchanged. - -- [ ] **Step 4: Commit** - -```bash -git add src/twinkle_client/types/model.py -git commit -m "feat(types): replace training state request types with ResumeFromCheckpointRequest" -``` - ---- - -### Task 6: Update server endpoints - -**Files:** -- Modify: `src/twinkle/server/model/twinkle_handlers.py:352-402` - -- [ ] **Step 1: Delete `load_training_state` endpoint (lines 352-376)** - -Remove the entire endpoint function. - -- [ ] **Step 2: Delete `read_training_progress` endpoint (lines 378-402)** - -Remove the entire endpoint function. - -- [ ] **Step 3: Add `resume_from_checkpoint` endpoint** - -Insert at the same location, following the existing endpoint pattern: - -```python -@app.post('/twinkle/resume_from_checkpoint', response_model=types.TrainingProgressResponse) -async def resume_from_checkpoint( - request: Request, - body: types.ResumeFromCheckpointRequest, - self: ModelManagement = Depends(self_fn), -): - token = await self._on_request_start(request) - - async def _task(): - checkpoint_dir = self._resolve_checkpoint_dir(body.name) - result = self.model.resume_from_checkpoint( - checkpoint_dir, - resume_only_model=body.resume_only_model, - adapter_name=body.adapter_name or token, - ) - return types.TrainingProgressResponse(result=result) - - return await run_task(self.schedule_task_and_wait(_task, task_type='resume')) -``` - -Note: Check how `load_training_state` resolves `checkpoint_dir` from `body.name` — replicate the same pattern. If there's a `_resolve_checkpoint_dir` helper, use it. Otherwise inline the resolution logic (typically `os.path.join(output_dir, name)` or direct path). - -- [ ] **Step 4: Commit** - -```bash -git add src/twinkle/server/model/twinkle_handlers.py -git commit -m "feat(server): replace training state endpoints with /resume_from_checkpoint" -``` - ---- - -### Task 7: Update client SDK - -**Files:** -- Modify: `src/twinkle_client/model/multi_lora_transformers.py:192-208` -- Modify: `client_tools/client_generator.py:621-637` - -- [ ] **Step 1: Update `src/twinkle_client/model/multi_lora_transformers.py`** - -Delete `load_training_state` (lines 192-199) and `read_training_progress` (lines 201-208). Replace with: - -```python -def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: - response = http_post( - url=f'{self.server_url}/resume_from_checkpoint', - json_data={'name': name, 'adapter_name': self.adapter_name, - 'resume_only_model': resume_only_model, **kwargs} - ) - response.raise_for_status() - return TrainingProgressResponse(**response.json()).result -``` - -- [ ] **Step 2: Update `client_tools/client_generator.py`** - -Delete `load_training_state` (lines 621-628) and `read_training_progress` (lines 630-637). Replace with the same `resume_from_checkpoint` method as above. - -- [ ] **Step 3: Commit** - -```bash -git add src/twinkle_client/model/multi_lora_transformers.py client_tools/client_generator.py -git commit -m "feat(client): replace training state methods with resume_from_checkpoint" -``` - ---- - -## Chunk 3: Cookbook and Documentation - -### Task 8: Update cookbook examples - -**Files:** -- Modify: `cookbook/transformers/resume_utils.py:16-55` -- Modify: `cookbook/client/twinkle/self_host/self_congnition.py:102-110` - -- [ ] **Step 1: Rewrite `resume_from_checkpoint` in `cookbook/transformers/resume_utils.py`** - -The old helper function manually orchestrated model + dataloader state. Replace the function body (lines 16-55) with a simplified version that delegates to the new model API: - -```python -def resume_from_checkpoint(model, dataloader, checkpoint_path, *, resume_only_model=False, - ignore_data_skip=False, adapter_name=None) -> int: - kwargs = {} - if adapter_name: - kwargs['adapter_name'] = adapter_name - - progress = model.resume_from_checkpoint( - checkpoint_path, resume_only_model=resume_only_model, **kwargs) - - consumed_train_samples = int(progress.get('consumed_train_samples', 0)) - if not ignore_data_skip and consumed_train_samples > 0: - dataloader.resume_from_checkpoint(consumed_train_samples) - - return consumed_train_samples -``` - -This keeps the helper for backward compatibility with existing cookbook scripts that call it, but the implementation now delegates to the model's own method. - -- [ ] **Step 2: Update `cookbook/client/twinkle/self_host/self_congnition.py`** - -Replace the resume block (around lines 102-110): - -```python -# Before: -consumed_train_samples = 0 -global_step = 0 -if resume_path: - logger.info(f'Resuming model weights from {resume_path}') - model.load(resume_path) - trainer_state = model.load_training_state(resume_path) - dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) - consumed_train_samples = int(trainer_state['consumed_train_samples']) - global_step = int(trainer_state['cur_step']) -``` - -With: - -```python -consumed_train_samples = 0 -global_step = 0 -if resume_path: - logger.info(f'Resuming from checkpoint {resume_path}') - progress = model.resume_from_checkpoint(resume_path) - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - consumed_train_samples = int(progress['consumed_train_samples']) - global_step = int(progress['cur_step']) -``` - -- [ ] **Step 3: Commit** - -```bash -git add cookbook/transformers/resume_utils.py cookbook/client/twinkle/self_host/self_congnition.py -git commit -m "refactor(cookbook): use model.resume_from_checkpoint API" -``` - ---- - -### Task 9: Update documentation - -**Files:** -- Modify: `docs/source_en/Components/Model/TransformersModel.md:54-65` -- Modify: `docs/source_zh/组件/模型/TransformersModel.md:54-65` -- Modify: `docs/source_en/Usage Guide/Quick-Start.md:289-296` -- Modify: `docs/source_zh/使用指引/快速开始.md:290-297` -- Modify: `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md:141,191` -- Modify: `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md:141,181` - -- [ ] **Step 1: Update English TransformersModel.md (lines 54-65)** - -Replace the checkpoint section with: - -```markdown -### Checkpoint and Resume - -- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. -- `model.resume_from_checkpoint(checkpoint_dir)` restores full training state (weights, optimizer, scheduler, scaler, RNG) and returns `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`. -- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` loads weights only and returns progress metadata without restoring optimizer state. -- `dataloader.resume_from_checkpoint(consumed_train_samples)` skips already-consumed samples. -``` - -- [ ] **Step 2: Update Chinese TransformersModel.md (lines 54-65)** - -Mirror the English changes in Chinese: - -```markdown -### 检查点保存与续训 - -- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 保存权重、优化器、调度器、scaler、RNG 状态和 `trainer_state.json`。 -- `model.resume_from_checkpoint(checkpoint_dir)` 恢复完整训练状态(权重、优化器、调度器、scaler、RNG),返回 `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`。 -- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` 仅加载权重并返回进度元数据,不恢复优化器状态。 -- `dataloader.resume_from_checkpoint(consumed_train_samples)` 跳过已消费的样本。 -``` - -- [ ] **Step 3: Update Quick-Start docs (EN and ZH)** - -In both `docs/source_en/Usage Guide/Quick-Start.md` and `docs/source_zh/使用指引/快速开始.md`, replace `model.load_training_state(resume_path)` references with: - -```python -progress = model.resume_from_checkpoint(resume_path) -dataloader.resume_from_checkpoint(progress['consumed_train_samples']) -``` - -Update the explanatory text accordingly. - -- [ ] **Step 4: Update Twinkle-Client docs (EN and ZH)** - -In both `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md`, replace `model.load_training_state(resume_path)` references with `model.resume_from_checkpoint(resume_path)`. - -- [ ] **Step 5: Commit** - -```bash -git add docs/ -git commit -m "docs: update checkpoint/resume documentation for unified API" -``` - ---- - -### Task 10: Final grep verification - -- [ ] **Step 1: Verify no stale references remain** - -```bash -grep -rn "load_training_state\|read_training_progress" src/ cookbook/ client_tools/ docs/ --include="*.py" --include="*.md" -``` - -Expected: Only hits in `docs/superpowers/` (our spec/plan files). No hits in source code, cookbook, or user-facing docs. - -- [ ] **Step 2: Run pre-commit hooks** - -```bash -pre-commit run --all-files -``` - -Fix any formatting issues (isort, yapf, flake8). - -- [ ] **Step 3: Final commit if needed** - -```bash -git add -A -git commit -m "chore: fix formatting after resume API refactor" -``` diff --git a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md deleted file mode 100644 index a9f62b20..00000000 --- a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.md +++ /dev/null @@ -1,181 +0,0 @@ -# Unified `resume_from_checkpoint` API Design - -## Problem - -The current checkpoint resume API on the `resume_from_ckpt` branch exposes two similar methods (`load_training_state` and `read_training_progress`) that are hard to distinguish. The caller must manually orchestrate state restoration across model and dataloader, acting as a data courier between components. Additionally, the Megatron backend lacks these methods entirely, creating an asymmetric API surface. - -## Design Principle - -Each component is responsible for its own state restoration. The caller only orchestrates — it does not transport data between components. - -## Target API - -```python -progress = model.resume_from_checkpoint(checkpoint_path) -dataloader.resume_from_checkpoint(progress['consumed_train_samples']) -``` - -Two lines. Both backends. No `resume_utils.py` helper needed. - -## Return Value Contract - -`model.resume_from_checkpoint()` returns a dict with exactly these keys: - -```python -{ - 'cur_step': int, # optimizer step count - 'consumed_train_samples': int, # total samples consumed - 'gradient_accumulation_steps': int, # GAS value at save time -} -``` - -Backend-specific state (optimizer tensors, scaler, RNG, mcore sharded state) is restored internally and not exposed. - -## Component Changes - -### 1. TwinkleModel Base Class (`src/twinkle/model/base.py`) - -Add abstract method: - -```python -@abstractmethod -def resume_from_checkpoint( - self, - checkpoint_dir: str, - *, - resume_only_model: bool = False, - **kwargs, -) -> Dict[str, Any]: - ... -``` - -Parameters: -- `checkpoint_dir`: Path to the checkpoint directory. -- `resume_only_model`: If True, load weights only — skip optimizer/scheduler/RNG restoration. Useful for fine-tuning with a different optimizer config. -- `**kwargs`: Backend-specific args (e.g., `adapter_name`). - -### 2. TransformersModel (`src/twinkle/model/transformers/transformers.py`) - -Delete public methods: `load_training_state()`, `read_training_progress()`. - -Retain private helpers: `_save_training_state()`, `_load_optimizer()`, `_load_scaler_state()`, `_load_rng_state()`, `_get_training_rng_state()`. - -New implementation: - -```python -@remote_function() -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', '') - - # Load adapter weights if checkpoint contains adapter files. - has_adapter = ( - os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) - or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) - ) - if has_adapter: - self.load(checkpoint_dir, adapter_name=adapter_name) - - # Read trainer_state.json. - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - # Full restore: optimizer, scheduler, scaler, RNG. - if not resume_only_model: - optimizer_group = self._get_optimizer_group(adapter_name) - self._load_optimizer(checkpoint_dir, optimizer_group, adapter_name) - self._load_scaler_state(checkpoint_dir) - self._load_rng_state(checkpoint_dir) - optimizer_group.cur_step = trainer_state['cur_step'] - optimizer_group.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -Full-parameter training: weights are loaded at model initialization time, so `has_adapter` is False and `self.load()` is skipped. Only training state is restored. - -### 3. MegatronModel (`src/twinkle/model/megatron/megatron.py`) - -**save() change:** When `save_optimizer=True`, also write `trainer_state.json`: - -```python -if save_optimizer: - self._save_mcore_optimizer(checkpoint_dir, optimizer_config=optimizer_config, **kwargs) - trainer_state = { - 'checkpoint_version': 1, - 'cur_step': optimizer_config.cur_step, - 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), - 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, - } - state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - if self.device_mesh.rank == 0: - with open(state_path, 'w') as f: - json.dump(trainer_state, f, indent=2) -``` - -**New resume_from_checkpoint():** - -```python -@remote_function(dispatch='all') -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', self._get_default_group()) - - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - self.load(checkpoint_dir, load_optimizer=not resume_only_model, - adapter_name=adapter_name, **kwargs) - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -Megatron's `load(load_optimizer=True)` already restores optimizer/scheduler/RNG/cur_step via `_load_mcore_optimizer`. The `resume_from_checkpoint` wrapper adds `trainer_state.json` reading for `consumed_train_samples`. - -### 4. DataLoader (`src/twinkle/dataloader/dataloader.py`) - -New method: - -```python -def resume_from_checkpoint(self, consumed_train_samples, **kwargs): - self.skip_consumed_samples(consumed_train_samples) -``` - -`skip_consumed_samples` is retained as-is (not renamed) for backward compatibility. `resume_from_checkpoint` is the recommended public API going forward. - -### 5. Server Endpoints (`src/twinkle/server/model/twinkle_handlers.py`) - -- Delete: `/twinkle/load_training_state`, `/twinkle/read_training_progress` -- Add: `/twinkle/resume_from_checkpoint` accepting `checkpoint_dir` and `resume_only_model` parameters - -### 6. Client SDK (`src/twinkle_client/`, `client_tools/client_generator.py`) - -- Delete: `load_training_state()`, `read_training_progress()` client methods -- Add: `resume_from_checkpoint()` client method - -### 7. Cookbook Changes - -- Delete `resume_from_checkpoint()` helper from `cookbook/transformers/resume_utils.py` (functionality now lives in the model) -- Update all cookbook examples to use the new two-line API - -### 8. Documentation - -Update `docs/source_en/Components/Model/TransformersModel.md` and corresponding Chinese docs to reflect the new API. - -## Migration Summary - -| Before | After | -|--------|-------| -| `model.load(path)` | `progress = model.resume_from_checkpoint(path)` | -| `model.load_training_state(path)` | (merged into above) | -| `model.read_training_progress(path)` | `progress = model.resume_from_checkpoint(path, resume_only_model=True)` | -| `dataloader.skip_consumed_samples(n)` | `dataloader.resume_from_checkpoint(n)` | -| `resume_from_checkpoint(model, dataloader, ...)` (cookbook util) | Two-line inline call | diff --git a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md b/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md deleted file mode 100644 index 3323170d..00000000 --- a/docs/superpowers/specs/2026-04-21-unified-resume-api-design.zh.md +++ /dev/null @@ -1,181 +0,0 @@ -# 统一的 `resume_from_checkpoint` API 设计 - -## 问题 - -当前在 `resume_from_ckpt` 分支上的断点续训 API 暴露了两个相似的方法(`load_training_state` 和 `read_training_progress`),难以区分。调用方必须手动编排模型和数据加载器之间的状态恢复,充当组件之间的数据搬运工。此外,Megatron 后端完全没有这些方法,导致 API 表面不对称。 - -## 设计原则 - -每个组件负责自身的状态恢复。调用方只负责编排 —— 不在组件之间搬运数据。 - -## 目标 API - -```python -progress = model.resume_from_checkpoint(checkpoint_path) -dataloader.resume_from_checkpoint(progress['consumed_train_samples']) -``` - -两行代码。两个后端。不再需要 `resume_utils.py` 辅助工具。 - -## 返回值约定 - -`model.resume_from_checkpoint()` 返回一个 dict,包含以下确切的键: - -```python -{ - 'cur_step': int, # 优化器步数 - 'consumed_train_samples': int, # 已消耗的总样本数 - 'gradient_accumulation_steps': int, # 保存时的 GAS 值 -} -``` - -后端特定的状态(优化器张量、scaler、RNG、mcore 分片状态)在内部恢复,不对外暴露。 - -## 组件变更 - -### 1. TwinkleModel 基类 (`src/twinkle/model/base.py`) - -添加抽象方法: - -```python -@abstractmethod -def resume_from_checkpoint( - self, - checkpoint_dir: str, - *, - resume_only_model: bool = False, - **kwargs, -) -> Dict[str, Any]: - ... -``` - -参数说明: -- `checkpoint_dir`: 检查点目录的路径。 -- `resume_only_model`: 如果为 True,则仅加载权重 —— 跳过优化器/调度器/RNG 的恢复。适用于使用不同优化器配置进行微调的场景。 -- `**kwargs`: 后端特定的参数(例如 `adapter_name`)。 - -### 2. TransformersModel (`src/twinkle/model/transformers/transformers.py`) - -删除公共方法:`load_training_state()`、`read_training_progress()`。 - -保留私有辅助方法:`_save_training_state()`、`_load_optimizer()`、`_load_scaler_state()`、`_load_rng_state()`、`_get_training_rng_state()`。 - -新实现: - -```python -@remote_function() -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', '') - - # 如果检查点包含适配器文件,则加载适配器权重。 - has_adapter = ( - os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors')) - or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')) - ) - if has_adapter: - self.load(checkpoint_dir, adapter_name=adapter_name) - - # 读取 trainer_state.json。 - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - # 完整恢复:优化器、调度器、scaler、RNG。 - if not resume_only_model: - optimizer_group = self._get_optimizer_group(adapter_name) - self._load_optimizer(checkpoint_dir, optimizer_group, adapter_name) - self._load_scaler_state(checkpoint_dir) - self._load_rng_state(checkpoint_dir) - optimizer_group.cur_step = trainer_state['cur_step'] - optimizer_group.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -全参数训练:权重在模型初始化时加载,因此 `has_adapter` 为 False,`self.load()` 被跳过。仅恢复训练状态。 - -### 3. MegatronModel (`src/twinkle/model/megatron/megatron.py`) - -**save() 变更:** 当 `save_optimizer=True` 时,同时写入 `trainer_state.json`: - -```python -if save_optimizer: - self._save_mcore_optimizer(checkpoint_dir, optimizer_config=optimizer_config, **kwargs) - trainer_state = { - 'checkpoint_version': 1, - 'cur_step': optimizer_config.cur_step, - 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), - 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, - } - state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - if self.device_mesh.rank == 0: - with open(state_path, 'w') as f: - json.dump(trainer_state, f, indent=2) -``` - -**新的 resume_from_checkpoint():** - -```python -@remote_function(dispatch='all') -def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): - adapter_name = kwargs.get('adapter_name', self._get_default_group()) - - trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') - with open(trainer_state_path, 'r') as f: - trainer_state = json.load(f) - - self.load(checkpoint_dir, load_optimizer=not resume_only_model, - adapter_name=adapter_name, **kwargs) - - return { - 'cur_step': trainer_state['cur_step'], - 'consumed_train_samples': trainer_state['consumed_train_samples'], - 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], - } -``` - -Megatron 的 `load(load_optimizer=True)` 已经通过 `_load_mcore_optimizer` 恢复了优化器/调度器/RNG/cur_step。`resume_from_checkpoint` 包装器增加了 `trainer_state.json` 的读取,以获取 `consumed_train_samples`。 - -### 4. 数据加载器 (`src/twinkle/dataloader/dataloader.py`) - -新方法: - -```python -def resume_from_checkpoint(self, consumed_train_samples, **kwargs): - self.skip_consumed_samples(consumed_train_samples) -``` - -`skip_consumed_samples` 保留原样(不更名)以保持向后兼容。`resume_from_checkpoint` 是今后推荐的公共 API。 - -### 5. 服务端接口 (`src/twinkle/server/model/twinkle_handlers.py`) - -- 删除:`/twinkle/load_training_state`、`/twinkle/read_training_progress` -- 新增:`/twinkle/resume_from_checkpoint`,接受 `checkpoint_dir` 和 `resume_only_model` 参数 - -### 6. 客户端 SDK (`src/twinkle_client/`、`client_tools/client_generator.py`) - -- 删除:`load_training_state()`、`read_training_progress()` 客户端方法 -- 新增:`resume_from_checkpoint()` 客户端方法 - -### 7. Cookbook 变更 - -- 删除 `cookbook/transformers/resume_utils.py` 中的 `resume_from_checkpoint()` 辅助函数(功能现已内置于模型中) -- 更新所有 cookbook 示例以使用新的两行 API - -### 8. 文档 - -更新 `docs/source_en/Components/Model/TransformersModel.md` 及对应的中文文档,以反映新的 API。 - -## 迁移摘要 - -| 之前 | 之后 | -|--------|-------| -| `model.load(path)` | `progress = model.resume_from_checkpoint(path)` | -| `model.load_training_state(path)` | (合并到上方) | -| `model.read_training_progress(path)` | `progress = model.resume_from_checkpoint(path, resume_only_model=True)` | -| `dataloader.skip_consumed_samples(n)` | `dataloader.resume_from_checkpoint(n)` | -| `resume_from_checkpoint(model, dataloader, ...)` (cookbook 工具函数) | 两行内联调用 | From d0219dfe40a223c4dd76a61ee20d5e51cf354d26 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 28 Apr 2026 10:25:34 +0800 Subject: [PATCH 52/60] wip --- src/twinkle/model/transformers/transformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a3e89e99..b3d1b69e 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -870,7 +870,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) if kwargs.get('save_optimizer', False): - self._save_optimizer(checkpoint_dir, adapter_name=adapter_name) self._save_training_state( checkpoint_dir, adapter_name=adapter_name, @@ -894,6 +893,8 @@ def _save_optimizer(self, output_dir, **kwargs): def _save_training_state(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + self._save_optimizer(output_dir, adapter_name=adapter_name) + optimizer_config = self.optimizer_group[adapter_name] if not Platform.is_master(): From 9d5327df794a5452267ec89a71523b4a6265a49c Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 28 Apr 2026 14:17:51 +0800 Subject: [PATCH 53/60] update --- cookbook/client/twinkle/self_host/self_cognition.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cookbook/client/twinkle/self_host/self_cognition.py b/cookbook/client/twinkle/self_host/self_cognition.py index a8010586..5d7fa666 100644 --- a/cookbook/client/twinkle/self_host/self_cognition.py +++ b/cookbook/client/twinkle/self_host/self_cognition.py @@ -99,17 +99,19 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint + start_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for _, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # Forward pass + backward pass (computes gradients) model.forward_backward(inputs=batch) @@ -126,7 +128,6 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - cur_step = dataloader.get_state()['consumed_train_samples'] // 4 if cur_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) From 85b7cf8ee32a6e99649d3778807c016c31cf0847 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 28 Apr 2026 14:46:04 +0800 Subject: [PATCH 54/60] doc --- cookbook/transformers/fsdp2_full.py | 112 ------------------ docs/source_en/Usage Guide/Quick-Start.md | 52 ++++---- .../Server and Client/Twinkle-Client.md | 32 +++-- ...53\351\200\237\345\274\200\345\247\213.md" | 52 ++++---- ...le\345\256\242\346\210\267\347\253\257.md" | 16 ++- 5 files changed, 79 insertions(+), 185 deletions(-) delete mode 100644 cookbook/transformers/fsdp2_full.py diff --git a/cookbook/transformers/fsdp2_full.py b/cookbook/transformers/fsdp2_full.py deleted file mode 100644 index 5aebf6ac..00000000 --- a/cookbook/transformers/fsdp2_full.py +++ /dev/null @@ -1,112 +0,0 @@ -from pathlib import Path - -from tqdm import tqdm - -import twinkle -from twinkle import DeviceMesh, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() - -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASET_ID = 'ms://swift/self-cognition' -TEMPLATE_NAME = 'Qwen3_5Template' -MODEL_NAME = 'twinkle大模型' -MODEL_AUTHOR = 'ModelScope社区' -FSDP_SIZE = 2 -DP_SIZE = 1 -BATCH_SIZE = 8 -LEARNING_RATE = 1e-5 -WEIGHT_DECAY = 0.01 -GRADIENT_ACCUMULATION_STEPS = 1 -LOG_INTERVAL = 1 -EVAL_INTERVAL = 20 -EVAL_SAMPLES = 100 -TRAIN_SAMPLES = 1000 - -import time -OUTPUT_DIR = f'./output/fsdp2_full_{int(time.time())}' -RESUME_FROM_CHECKPOINT = None -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False - -device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) - dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) - dataset.encode() - return dataset - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - model.save( - checkpoint_name, - output_dir=OUTPUT_DIR, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def evaluate(model): - dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) - for batch in tqdm(dataloader): - model.forward_only(inputs=batch) - model.calculate_loss() - return model.calculate_metric(is_training=False) - - -def train(): - dataset = build_dataset(TRAIN_SAMPLES) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - - model = TransformersModel(model_id=MODEL_ID) - model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - - model.set_optimizer( - optimizer_cls='AdamW', - lr=LEARNING_RATE, - weight_decay=WEIGHT_DECAY, - gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, - ) - model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - progress = model.resume_from_checkpoint(str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - - optimizer_group = model.optimizer_group[''] - best_loss = float('inf') - - for batch in dataloader: - model.forward_backward(inputs=batch) - model.clip_grad_and_step() - cur_step = optimizer_group.cur_step - if cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: - metrics = evaluate(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = cur_step - current_loss = float(metrics['loss']) - if current_loss < best_loss: - save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) - best_loss = current_loss - save_checkpoint(model, 'last-checkpoint', dataloader) - - -if __name__ == '__main__': - train() diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 9f4869b5..ef09b71c 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -232,31 +232,30 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ### Resume from Checkpoint -The local and `torchrun` training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py`. +The training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py`. -When saving a checkpoint intended for resumption, save both model weights and training progress: +**Saving a Checkpoint** ```python -def save_checkpoint(model, checkpoint_name, dataloader): - model.save( - checkpoint_name, - output_dir='./output/fsdp2', - adapter_name='default', - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) +model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name=ADAPTER_NAME, + save_optimizer=True, # Store optimizer state + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], # Persist training progress +) ``` -`save_optimizer=True` stores optimizer-related state, and `consumed_train_samples` is written into `trainer_state.json` so the dataloader can skip samples that have already been consumed. The `DataLoader` automatically tracks consumed samples internally — call `dataloader.get_state()` to retrieve the current count. +> `DataLoader` automatically tracks consumed samples internally — call `dataloader.get_state()` to retrieve the current count. -To resume training, restore the checkpoint before entering the main loop: +**Resuming Training** ```python from pathlib import Path RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False +RESUME_ONLY_MODEL = False # True: weights only, skip optimizer/scheduler restoration +IGNORE_DATA_SKIP = False # True: do not skip consumed samples from trainer_state.json if RESUME_FROM_CHECKPOINT: checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) @@ -265,26 +264,29 @@ if RESUME_FROM_CHECKPOINT: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` -This covers two common resume modes: +How the two flags combine: -- Full resume (default): restore weights, optimizer, scheduler, scaler, RNG state, and training progress, then skip consumed samples in the dataloader. -- Weights-only resume (`resume_only_model=True`): restore only model weights. This is useful when you want to continue with fresh optimizer state or intentionally restart the schedule. +| `RESUME_ONLY_MODEL` | `IGNORE_DATA_SKIP` | Effect | +|---|---|---| +| `False` (default) | `False` (default) | Full resume: restore weights + optimizer + scheduler + RNG, skip consumed data | +| `True` | `False` | Weights only, but still skip consumed data (restart optimization from fresh) | +| `True` | `True` | Weights only, restart dataset from the beginning | -When `RESUME_ONLY_MODEL=True`, `IGNORE_DATA_SKIP=False` still skips already consumed samples based on `trainer_state.json`. If you want to reload weights but restart the dataset from the beginning, set `IGNORE_DATA_SKIP=True`. +**LoRA / Adapter vs Full-Parameter Training** -The flow above is intended for LoRA / adapter training. For full-parameter training, restore model weights by passing the checkpoint path as `model_id` when constructing `TransformersModel`, instead of calling `model.load(...)`. For example: +The flow above uses LoRA as the default example. For full-parameter training, the only difference is in `TransformersModel` initialization — use the checkpoint path as `model_id` instead of the base model ID: ```python -resume_path = './output/fsdp2/last-checkpoint' +# LoRA / adapter: base model loaded from hub, checkpoint contains only adapter weights + training state +model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') +progress = model.resume_from_checkpoint(resume_path) + +# Full-parameter: model weights are saved entirely in the checkpoint — use it directly as model_id model = TransformersModel(model_id=resume_path) progress = model.resume_from_checkpoint(resume_path) -dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` -In other words: - -- LoRA / adapter resume: create `TransformersModel` from the original base model, then restore via `model.resume_from_checkpoint(...)`. -- Full-parameter resume: construct `TransformersModel(...)` with the checkpoint path as `model_id`, then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. +> All subsequent calls to `resume_from_checkpoint` and `dataloader.resume_from_checkpoint` are identical in both cases. ### Ray Training diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index a4bd971d..af53f5ab 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -133,40 +133,36 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) +start_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 6: Training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # Forward propagation + backward propagation model.forward_backward(inputs=batch) - # Gradient clipping + optimizer update (equivalent to clip_grad_norm / step / zero_grad / lr_step) + # Gradient clipping + optimizer update (equivalent to calling clip_grad_norm / step / zero_grad / lr_step in sequence) model.clip_grad_and_step() - if step % 2 == 0: - logger.info(f'Step {step // 2}, loss: {output}') - - # Gradient clipping - model.clip_grad_norm(1.0) - - # Optimizer update - model.step() - - # Zero gradients - model.zero_grad() - - # Learning rate scheduling - model.lr_step() + # Print metric every 2 steps (aligned with gradient_accumulation_steps) + if cur_step % 2 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: Save checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 8: Upload to ModelScope Hub (optional) @@ -185,7 +181,7 @@ For checkpoint resumption, the recommended client-side flow is: 2. Call `model.resume_from_checkpoint(resume_path)` to restore weights, optimizer, scheduler, RNG, and progress metadata. 3. Call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip already-consumed samples. -This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_congnition.py`. +This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_cognition.py`. ## Differences with Megatron Backend diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 6dceaa1f..07596cbd 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -233,31 +233,30 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ### 断点续训 -上面的本地训练和 `torchrun` 训练循环,都可以扩展为支持断点续训。完整示例可以直接参考 `cookbook/transformers/fsdp2.py`。 +上面的训练循环可以扩展为支持断点续训。完整示例可直接参考 `cookbook/transformers/fsdp2.py`。 -如果希望保存出来的 checkpoint 可以用于续训,保存时除了模型权重,还需要把训练进度一并落盘: +**保存检查点** ```python -def save_checkpoint(model, checkpoint_name, dataloader): - model.save( - checkpoint_name, - output_dir='./output/fsdp2', - adapter_name='default', - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) +model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name=ADAPTER_NAME, + save_optimizer=True, # 保存优化器状态 + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], # 落盘训练进度 +) ``` -其中,`save_optimizer=True` 会保存优化器相关状态,`consumed_train_samples` 会写入 `trainer_state.json`,用于恢复时让 dataloader 跳过已经消费过的数据。`DataLoader` 会自动追踪已消费的样本数,通过 `dataloader.get_state()` 即可获取当前计数。 +> `DataLoader` 内部自动追踪已消费样本数,通过 `dataloader.get_state()` 获取。 -恢复训练时,建议在进入主训练循环之前先加载 checkpoint: +**恢复训练** ```python from pathlib import Path RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False +RESUME_ONLY_MODEL = False # True: 仅恢复权重,不恢复优化器/调度器等训练状态 +IGNORE_DATA_SKIP = False # True: 不从 trainer_state.json 跳过已消费数据 if RESUME_FROM_CHECKPOINT: checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) @@ -266,26 +265,29 @@ if RESUME_FROM_CHECKPOINT: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` -这两行 API 覆盖了两种常见恢复模式: +两个开关的组合效果: -- 完整续训(默认):恢复权重、优化器、学习率调度器、梯度缩放器、随机数状态和训练进度,并让 dataloader 跳过已消费样本。 -- 仅恢复权重(`resume_only_model=True`):只加载模型权重,不恢复优化器等训练状态。适合希望沿用参数初始化、但重新开始优化过程的场景。 +| `RESUME_ONLY_MODEL` | `IGNORE_DATA_SKIP` | 效果 | +|---|---|---| +| `False`(默认) | `False`(默认) | 完整续训:恢复权重 + 优化器 + 调度器 + RNG,并跳过已消费数据 | +| `True` | `False` | 仅恢复权重,但仍跳过已消费数据(适合沿用权重、重新开始优化) | +| `True` | `True` | 仅恢复权重,从数据集开头重新训练 | -当 `RESUME_ONLY_MODEL=True` 且 `IGNORE_DATA_SKIP=False` 时,仍会根据 `trainer_state.json` 跳过已训练过的数据;如果你只想加载权重、但从数据集开头重新训练,可以把 `IGNORE_DATA_SKIP=True`。 +**LoRA / adapter vs 全参训练** -上面的恢复流程默认针对 LoRA / adapter 训练。对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 `model_id` 设为 checkpoint 路径,而不是再调用 `model.load(...)`。例如: +上述流程默认以 LoRA 为例。全参训练的恢复仅有一处不同——`TransformersModel` 初始化时,`model_id` 需要用 checkpoint 路径替代 base model ID: ```python -resume_path = './output/fsdp2/last-checkpoint' +# LoRA / adapter:base model 从 hub 加载,checkpoint 仅含 adapter 权重和训练状态 +model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') +progress = model.resume_from_checkpoint(resume_path) + +# 全参:模型权重已整体保存到 checkpoint,直接将其作为 model_id model = TransformersModel(model_id=resume_path) progress = model.resume_from_checkpoint(resume_path) -dataloader.resume_from_checkpoint(progress['consumed_train_samples']) ``` -也就是说: - -- LoRA / adapter 续训:先按原始 base model 创建 `TransformersModel`,再通过 `model.resume_from_checkpoint(...)` 恢复。 -- 全参续训:在 `TransformersModel(...)` 初始化时直接传入 checkpoint 路径作为 `model_id`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 +> 二者后续的 `resume_from_checkpoint` 及 `dataloader.resume_from_checkpoint` 调用完全一致。 ### Ray训练 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index e40ce7e1..967479b9 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -133,17 +133,19 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) +start_step = 0 if resume_path: logger.info(f'Resuming from checkpoint {resume_path}') progress = model.resume_from_checkpoint(resume_path) dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 6: 训练循环 logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # 前向传播 + 反向传播 model.forward_backward(inputs=batch) @@ -151,12 +153,16 @@ for epoch in range(3): model.clip_grad_and_step() # 每 2 步打印一次指标(与 gradient_accumulation_steps 对齐) - if step % 2 == 0: + if cur_step % 2 == 0: metric = model.calculate_metric(is_training=True) - logger.info(f'Epoch {epoch}, step {step}/{len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: 保存检查点 - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 8: 上传到 ModelScope Hub(可选) @@ -175,7 +181,7 @@ Twinkle Client 场景下,推荐的断点续训流程是: 2. 调用 `model.resume_from_checkpoint(resume_path)` 恢复权重、优化器、调度器、随机数状态和训练进度元数据。 3. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.resume_from_checkpoint(...)`,跳过已经训练过的数据。 -完整示例可直接参考 `cookbook/client/twinkle/self_host/self_congnition.py`。 +完整示例可直接参考 `cookbook/client/twinkle/self_host/self_cognition.py`。 ## Megatron 后端的差异 From 9af73bc22156073e4cdc5583c19a9bcdf8d88a2a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 29 Apr 2026 09:23:56 +0800 Subject: [PATCH 55/60] fix --- src/twinkle/model/megatron/multi_lora_megatron.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 4f67a450..140a1373 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -213,7 +213,7 @@ def _rank_local_optimizer_path(checkpoint_dir: str) -> str: return os.path.join(checkpoint_dir, f'optimizer_rank_{rank}.pt') @staticmethod - def _get_local_training_rng_state(): + def _save_local_training_rng_state(): from megatron.core import tensor_parallel rng_state = { @@ -227,7 +227,7 @@ def _get_local_training_rng_state(): return rng_state @staticmethod - def _restore_local_training_rng_state(rng_state): + def _load_local_training_rng_state(rng_state): from megatron.core import tensor_parallel random.setstate(rng_state['random_rng_state']) @@ -242,7 +242,7 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw state_dict = { 'checkpoint_version': 1, 'iteration': optimizer_config.cur_step, - 'rng_state': self._get_local_training_rng_state(), + 'rng_state': self._save_local_training_rng_state(), } if optimizer_config.optimizer is not None: state_dict['optimizer'] = optimizer_config.optimizer.state_dict() @@ -263,7 +263,7 @@ def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '' if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) if not no_load_rng and 'rng_state' in state_dict: - self._restore_local_training_rng_state(state_dict['rng_state']) + self._load_local_training_rng_state(state_dict['rng_state']) if optimizer_config is not None and 'iteration' in state_dict: optimizer_config.cur_step = state_dict['iteration'] From 482a451dbc88dd90f1bbbf4bbd769e6fea02fe77 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 30 Apr 2026 15:04:11 +0800 Subject: [PATCH 56/60] fix doc --- docs/source_en/Usage Guide/Quick-Start.md | 2 +- .../\345\277\253\351\200\237\345\274\200\345\247\213.md" | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index ef09b71c..185b91ac 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -471,7 +471,7 @@ python train.py A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs. -Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See `docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md` and `cookbook/client/twinkle/self_host/self_congnition.py`. +Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py). Suppose we start a service using eight GPUs. First, we need to start the Ray cluster: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 07596cbd..3f19af5b 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -470,7 +470,7 @@ python train.py ``` ### 远程训练 -client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 `docs/source_zh/使用指引/服务端和客户端/Twinkle客户端.md` 和 `cookbook/client/twinkle/self_host/self_congnition.py`。 +client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py)。 Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。 From 239641992a768f490598292527c12d58381cd5df Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 30 Apr 2026 15:57:33 +0800 Subject: [PATCH 57/60] fix --- src/twinkle/model/megatron/megatron.py | 19 +++++++++--------- .../transformers/multi_lora_transformers.py | 2 +- src/twinkle/server/model/twinkle_handlers.py | 1 - tests/dataloader/test_dataloader.py | 19 ++++++++---------- tests/dataloader/test_sampler.py | 20 ++++++++----------- 5 files changed, 26 insertions(+), 35 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 9fa3a46e..d6e32756 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -860,17 +860,16 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): ``no_load_rng``, etc.). """ resume = kwargs.pop('load_optimizer', False) - if output_dir is None and not resume: - if os.path.exists(name): - checkpoint_dir = name - else: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) - else: - if output_dir is None: - output_dir = 'output' + if output_dir is not None: checkpoint_dir = os.path.join(output_dir, name) + elif os.path.exists(name): + checkpoint_dir = name + elif not resume: + # load from hub + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) + else: + checkpoint_dir = os.path.join('output', name) adapter_name = kwargs.pop('adapter_name', self._get_default_group()) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 2c1ae9a9..cedd7af6 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -235,7 +235,7 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): return super().save(name, output_dir, interval, **kwargs) @remote_function() - def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs): + def load(self, name: str, output_dir: Optional[str] = None, **kwargs): adapter_name = kwargs.get('adapter_name') self._check_adapter_valid(adapter_name) with self.multi_adapter.save_context(kwargs.get('adapter_name')): diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 3989fed9..b9d3e295 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -9,7 +9,6 @@ from __future__ import annotations import asyncio -import os import torch import traceback from fastapi import Depends, FastAPI, HTTPException, Request diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index edad0dd3..bb04b62b 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -6,27 +6,24 @@ from pathlib import Path from torch.utils.data import Dataset as TorchDataset from torch.utils.data import IterableDataset as TorchIterableDataset +from unittest.mock import MagicMock import twinkle +import twinkle.hub.hub as _hub_module from twinkle import DeviceMesh from twinkle.data_format import Message from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta, IterableDataset from twinkle.processor import InputProcessor +twinkle.initialize(mode='local') -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -twinkle.initialize(mode='local') +@pytest.fixture(autouse=True) +def _disable_process_pool(monkeypatch): + mock_executor = MagicMock() + mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.') + monkeypatch.setattr(_hub_module, '_executor', mock_executor) TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true' diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index d90c8725..1c010808 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -1,30 +1,26 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import concurrent.futures import numpy as np import os import pytest from pathlib import Path from torch.utils.data import Dataset as TorchDataset from torch.utils.data import RandomSampler, SequentialSampler +from unittest.mock import MagicMock import twinkle +import twinkle.hub.hub as _hub_module from twinkle import DeviceMesh from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta +twinkle.initialize(mode='local') -class _NoOpProcessPoolExecutor: - - def __init__(self, *args, **kwargs): - pass - - def submit(self, fn, *args, **kwargs): - raise RuntimeError('Process pool is disabled in this test environment.') - - -concurrent.futures.ProcessPoolExecutor = _NoOpProcessPoolExecutor -twinkle.initialize(mode='local') +@pytest.fixture(autouse=True) +def _disable_process_pool(monkeypatch): + mock_executor = MagicMock() + mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.') + monkeypatch.setattr(_hub_module, '_executor', mock_executor) TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' From 9a6fbb980ffa4b59d2ca83ffa1a832144c653e33 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 30 Apr 2026 16:01:15 +0800 Subject: [PATCH 58/60] lint --- tests/dataloader/test_dataloader.py | 1 + tests/dataloader/test_sampler.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index bb04b62b..2da0a4f8 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -25,6 +25,7 @@ def _disable_process_pool(monkeypatch): mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.') monkeypatch.setattr(_hub_module, '_executor', mock_executor) + TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true' diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py index 1c010808..3b7a4ebc 100644 --- a/tests/dataloader/test_sampler.py +++ b/tests/dataloader/test_sampler.py @@ -22,6 +22,7 @@ def _disable_process_pool(monkeypatch): mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.') monkeypatch.setattr(_hub_module, '_executor', mock_executor) + TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data' From daa9202cd8e990b96f53eeebf3f6b6704b718190 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 30 Apr 2026 16:06:58 +0800 Subject: [PATCH 59/60] update cookbook --- cookbook/megatron/tp.py | 122 ++++++++++++++++++++------------- cookbook/megatron/tp_resume.py | 113 ------------------------------ 2 files changed, 76 insertions(+), 159 deletions(-) delete mode 100644 cookbook/megatron/tp_resume.py diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 4ea2e13e..214e58fa 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -1,81 +1,111 @@ -import os +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, tp=pp=dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +DP_SIZE = 2 +TP_SIZE = 2 +PP_SIZE = 2 +BATCH_SIZE = 16 +LEARNING_RATE = 1e-4 +LOG_INTERVAL = 5 +EVAL_INTERVAL = 20 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +OUTPUT_DIR = './output/megatron_tp' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=16) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): model.forward_only(inputs=batch) - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() - # Global batch size = 1, dp_size = 1 - dataloader = DataLoader(dataset=dataset, batch_size=16) - # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') + dataset = build_dataset(TRAIN_SAMPLES) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + model = MegatronModel(model_id=MODEL_ID) lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') - # Add a lora to model, with name `default` # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='default', lr=1e-4) - # Add LRScheduler for lora `default` + model.add_adapter_to_model(ADAPTER_NAME, lora_config) + model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + logger.info(get_device_placement()) - # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 - # lora: 10G * 8 - # full: 40G * 8 + + best_loss = float('inf') + for step, batch in enumerate(dataloader): - # Do forward and backward model.forward_backward(inputs=batch) - # Step model.clip_grad_and_step() - if step % 5 == 0: - # Print metric + if step % LOG_INTERVAL == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 20 == 0: - metrics = eval(model) + if step > 0 and step % EVAL_INTERVAL == 0: + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) if __name__ == '__main__': diff --git a/cookbook/megatron/tp_resume.py b/cookbook/megatron/tp_resume.py deleted file mode 100644 index a099789e..00000000 --- a/cookbook/megatron/tp_resume.py +++ /dev/null @@ -1,113 +0,0 @@ -from pathlib import Path - -from peft import LoraConfig -from tqdm import tqdm - -import twinkle -from twinkle import DeviceMesh, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import MegatronModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() - -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASET_ID = 'ms://swift/self-cognition' -TEMPLATE_NAME = 'Qwen3_5Template' -MODEL_NAME = 'twinkle大模型' -MODEL_AUTHOR = 'ModelScope社区' -DP_SIZE = 2 -TP_SIZE = 2 -PP_SIZE = 2 -BATCH_SIZE = 16 -LEARNING_RATE = 1e-4 -LOG_INTERVAL = 5 -EVAL_INTERVAL = 20 -EVAL_SAMPLES = 100 -TRAIN_SAMPLES = 1000 - -OUTPUT_DIR = './output/megatron_tp' -RESUME_FROM_CHECKPOINT = None -RESUME_ONLY_MODEL = False -IGNORE_DATA_SKIP = False -ADAPTER_NAME = 'default' - -device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) - dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) - dataset.encode() - return dataset - - -def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader): - model.save( - checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def evaluate(model): - dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) - for batch in tqdm(dataloader): - model.forward_only(inputs=batch) - return model.calculate_metric(is_training=False) - - -def train(): - dataset = build_dataset(TRAIN_SAMPLES) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - - model = MegatronModel(model_id=MODEL_ID) - - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') - - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model(ADAPTER_NAME, lora_config) - model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) - model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - - best_loss = float('inf') - - for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch) - model.clip_grad_and_step() - if step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % EVAL_INTERVAL == 0: - metrics = evaluate(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = step - current_loss = float(metrics['loss']) - if current_loss < best_loss: - save_checkpoint(model, f'checkpoint-{step}', dataloader) - best_loss = current_loss - save_checkpoint(model, 'last-checkpoint', dataloader) - - -if __name__ == '__main__': - train() From a75f8b1d71b992056f304288bb203d96893993dc Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 30 Apr 2026 17:22:48 +0800 Subject: [PATCH 60/60] fix --- cookbook/megatron/tp.py | 4 +++- src/twinkle/model/megatron/megatron.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 214e58fa..650cf67b 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -75,6 +75,7 @@ def train(): model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + start_step = 0 if RESUME_FROM_CHECKPOINT: checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() kwargs = {} @@ -84,6 +85,7 @@ def train(): str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) if not IGNORE_DATA_SKIP: dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] logger.info(get_device_placement()) logger.info(model.get_train_configs()) @@ -91,7 +93,7 @@ def train(): best_loss = float('inf') - for step, batch in enumerate(dataloader): + for step, batch in enumerate(dataloader, start=start_step): model.forward_backward(inputs=batch) model.clip_grad_and_step() if step % LOG_INTERVAL == 0: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index d6e32756..f61e66ab 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1117,7 +1117,8 @@ def _load_mcore_optimizer( # Restore optimizer + LR scheduler. if not no_load_optim and optimizer is not None and 'optimizer' in state_dict: - optimizer.load_state_dict(state_dict['optimizer']) + with torch.no_grad(): + optimizer.load_state_dict(state_dict['optimizer']) if (opt_param_scheduler is not None and 'opt_param_scheduler' in state_dict): opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'], )