Conversation
Code Review
This pull request implements a comprehensive "Strict Resume" feature for Transformers models, enabling the restoration of full training state including optimizer, scheduler, scaler, RNG states, and data progress. Key changes involve implementing load_training_state and read_training_progress across the model, server, and client layers, alongside dataloader enhancements to support sample-level skipping for map-style datasets. Feedback highlights several critical improvements: ensuring deterministic RNG in distributed settings by avoiding unseeded random states, addressing the deprecated use of StopIteration in generators, improving security by using weights_only=True during checkpoint loading, and removing an accidental BOM character in the client generator. Additionally, a more robust approach for re-initializing the dataloader is suggested to avoid modifying private PyTorch attributes.
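The summary above flags deterministic RNG in distributed settings as a critical point. A minimal, torch-free sketch of the idea using only the stdlib `random` module (the `seed_worker` name and `rank` parameter are illustrative, not from the PR):

```python
import random

def seed_worker(base_seed: int, rank: int) -> random.Random:
    """Derive a deterministic, per-rank RNG from a single base seed.

    Using an explicitly seeded Random instance avoids the unseeded
    global state that makes distributed runs non-reproducible.
    """
    return random.Random(base_seed + rank)

# Two workers with the same base seed and rank produce identical streams.
rng_a = seed_worker(1234, rank=0)
rng_b = seed_worker(1234, rank=0)
assert [rng_a.random() for _ in range(3)] == [rng_b.random() for _ in range(3)]

# Different ranks diverge, so per-worker shuffles differ but stay reproducible.
assert seed_worker(1234, rank=1).random() != seed_worker(1234, rank=0).random()
```

The same offset-by-rank pattern applies to `torch.manual_seed` and `numpy.random.default_rng` in the real training loop.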
```python
self.dataloader.__initialized = False
self._rebuild_sampler_stack()
self.dataloader.__initialized = True
```
Accessing and modifying the private attribute `__initialized` of `torch.utils.data.DataLoader` is brittle and relies on internal implementation details of PyTorch that could change. A safer way to update the sampler stack after the dataloader has been created is simply to re-instantiate the underlying `self.dataloader` from the stored `self.dataloader_params`.
```diff
- self.dataloader.__initialized = False
- self._rebuild_sampler_stack()
- self.dataloader.__initialized = True
+ if self.dataloader is not None:
+     self.dataloader = None
+ self._lazy_init_dataloader()
```
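The suggested fix follows a general pattern: instead of toggling a private flag on a live object, keep the constructor arguments around and rebuild the object whenever its configuration changes. A torch-free sketch of that pattern (the `Loader` and `Trainer` names are stand-ins; the real code would re-create `torch.utils.data.DataLoader` from `self.dataloader_params`):

```python
class Loader:
    """Stand-in for torch.utils.data.DataLoader in this sketch."""
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

class Trainer:
    def __init__(self, dataset, batch_size):
        # Store the params needed to rebuild the loader later.
        self.dataloader_params = {"dataset": dataset, "batch_size": batch_size}
        self.dataloader = None
        self._lazy_init_dataloader()

    def _lazy_init_dataloader(self):
        if self.dataloader is None:
            self.dataloader = Loader(**self.dataloader_params)

    def update_sampler(self, new_batch_size):
        # Rebuild from stored params instead of mutating private internals.
        self.dataloader_params["batch_size"] = new_batch_size
        self.dataloader = None
        self._lazy_init_dataloader()

t = Trainer(dataset=[1, 2, 3], batch_size=2)
old = t.dataloader
t.update_sampler(new_batch_size=1)
assert t.dataloader is not old          # a fresh loader was constructed
assert t.dataloader.batch_size == 1
```

The cost is re-running the loader constructor, which is cheap relative to relying on a private attribute that any PyTorch release may rename.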
/gemini summary
```python
        )
        response.raise_for_status()

    def load_training_state(self, name: str, **kwargs) -> Dict[str, Any]:
```
What is the difference between `load_training_state` and `read_training_progress`? Could they be merged into one?
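One possible answer to the merge question, sketched as a single hypothetical method that returns both the restored state and the progress counters in one call (the signature and dict keys below are assumptions, not the PR's actual API):

```python
from typing import Any, Dict

def load_training_state(name: str, **kwargs) -> Dict[str, Any]:
    """Hypothetical merged API: return restored state and progress together.

    Instead of two round-trips (load_training_state + read_training_progress),
    a single call could return one dict with both sections.
    """
    # In a real implementation these would be read from the checkpoint `name`.
    state = {"optimizer": {}, "scheduler": {}, "rng": {}}
    progress = {"consumed_train_samples": 0, "epoch": 0}
    return {"state": state, "progress": progress}

result = load_training_state("twinkle-epoch-0")
assert set(result) == {"state", "progress"}
assert "consumed_train_samples" in result["progress"]
```

A counterargument for keeping them separate: the client may want to read progress (a small payload) without transferring the full optimizer state.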
```python
twinkle_path = model.save(
    name=f'twinkle-epoch-{epoch}',
    save_optimizer=True,
    consumed_train_samples=consumed_train_samples,
```
Should this value come from `dataloader.get_consumed_samples()`?
Or `dataloader.get_state()`, which would be more general.
Also, please additionally test torchrun/ray compatibility here, as well as compatibility with both the Megatron and Transformers models.
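The `dataloader.get_state()` suggestion could look roughly like this: a wrapper that tracks consumed samples and exposes generic save/restore hooks (all names here are illustrative, not the PR's API):

```python
class StatefulLoader:
    """Illustrative dataloader wrapper with generic state save/restore."""
    def __init__(self, dataset):
        self.dataset = dataset
        self.consumed_samples = 0

    def __iter__(self):
        # Skip samples already consumed, enabling sample-level resume
        # for map-style datasets.
        for i in range(self.consumed_samples, len(self.dataset)):
            self.consumed_samples += 1
            yield self.dataset[i]

    def get_state(self):
        # More general than get_consumed_samples(): new fields can be
        # added here later without changing the checkpointing call sites.
        return {"consumed_samples": self.consumed_samples}

    def load_state(self, state):
        self.consumed_samples = state["consumed_samples"]

loader = StatefulLoader([10, 20, 30, 40])
it = iter(loader)
next(it), next(it)                       # consume two samples
ckpt = loader.get_state()
assert ckpt == {"consumed_samples": 2}

resumed = StatefulLoader([10, 20, 30, 40])
resumed.load_state(ckpt)
assert list(resumed) == [30, 40]         # resumes after the skipped samples
```

Returning a dict from `get_state()` also gives both the Megatron and Transformers backends a single shape to serialize into the checkpoint.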
```python
adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
optimizer_config = self.optimizer_group[adapter_name]

if not Platform.is_master():
```
Both ray and torchrun need to be verified as correct here, and the Megatron side needs corresponding consideration as well.
…s with resume_from_checkpoint
…lti_lora, and docs
…sumeFromCheckpointRequest
PR type
PR information
Implement full training-state restoration in TransformersModel and MultiLoraModel, including the optimizer, scheduler, RNG configuration, and dataset skipping.