feat: add resume_from_checkpoint option to training configuration #1406

hoangvictor wants to merge 2 commits into modelscope:main
Conversation
Code Review
This pull request introduces the ability to resume training from a checkpoint by adding a new CLI argument and updating the training runner logic. The changes include loading the accelerator state, calculating the correct starting epoch and step based on the checkpoint name, and implementing full checkpoint saving. The review feedback identifies several critical issues: a potential UnboundLocalError if arguments are not provided, a redundant logging call that would cause incorrect step counting, unsafe access to the output path, and the need for better exception handling and CLI documentation.
```python
    args=None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        resume_from_checkpoint = args.resume_from_checkpoint

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
```
The variable resume_from_checkpoint is only defined if args is not None. If args is None, accessing it later at line 50 will raise an UnboundLocalError. It should be initialized to None at the start of the function to ensure safety.
Suggested change:

```python
    args=None,
):
    resume_from_checkpoint = None
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        resume_from_checkpoint = args.resume_from_checkpoint
```
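A minimal, self-contained repro of the scoping issue described above. The `run`/`run_fixed` functions are hypothetical stand-ins for the training function, not the project's actual code:

```python
from types import SimpleNamespace

def run(args=None):
    # Bug: resume_from_checkpoint is only bound inside the branch,
    # so reading it when args is None raises UnboundLocalError.
    if args is not None:
        resume_from_checkpoint = args.resume_from_checkpoint
    return resume_from_checkpoint

def run_fixed(args=None):
    # Fix suggested by the review: bind a safe default up front.
    resume_from_checkpoint = None
    if args is not None:
        resume_from_checkpoint = args.resume_from_checkpoint
    return resume_from_checkpoint

try:
    run(None)
except UnboundLocalError:
    print("run(None) raises UnboundLocalError")

print(run_fixed(None))  # None — safe default
print(run_fixed(SimpleNamespace(resume_from_checkpoint="ckpt")))  # ckpt
```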
```python
try:
    global_step = int(os.path.basename(ckpt_path).split("-")[-1])
except:
    global_step = 0
```
Avoid using a bare except: block as it can catch unexpected exceptions like KeyboardInterrupt. Additionally, when resuming from a checkpoint, the model_logger.num_steps should be synchronized with the recovered global_step to ensure that subsequent logging and checkpoint filenames (e.g., step-1001.safetensors) are correct.
Suggested change:

```python
try:
    global_step = int(os.path.basename(ckpt_path).split("-")[-1])
except Exception:
    global_step = 0
model_logger.num_steps = global_step
```
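The parsing logic under review can be exercised in isolation. A small sketch, wrapping the same expression in a hypothetical `parse_global_step` helper, assuming checkpoint directories are named `checkpoint-<step>` as in the save path used elsewhere in this PR:

```python
import os

def parse_global_step(ckpt_path):
    """Recover the step counter from a checkpoint directory name
    such as 'output/checkpoint-1000'."""
    try:
        # Last hyphen-separated token of the directory name is the step.
        return int(os.path.basename(ckpt_path).split("-")[-1])
    except Exception:  # narrow except, per the review feedback
        return 0

print(parse_global_step("output/checkpoint-1000"))  # 1000
print(parse_global_step("output/final_model"))      # 0 (no step suffix)
```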
```python
global_step += 1

# existing logging (LoRA weights etc.)
model_logger.on_step_end(
    accelerator, model, save_steps, loss=loss
)
```
The call to model_logger.on_step_end at line 104 is redundant because it is already called inside the accelerator.accumulate block at line 99. Calling it twice per batch will cause the logger's internal step counter to increment twice as fast as intended, leading to incorrect checkpointing intervals and filenames.
Suggested change:

```python
global_step += 1
```
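To see why the duplicate call matters, here is a toy illustration. `ToyLogger` is hypothetical and only mimics a step counter with a save interval, not the project's actual ModelLogger:

```python
class ToyLogger:
    """Minimal stand-in: counts steps and saves every save_steps steps."""
    def __init__(self, save_steps=500):
        self.save_steps = save_steps
        self.num_steps = 0

    def on_step_end(self):
        self.num_steps += 1
        if self.num_steps % self.save_steps == 0:
            print(f"saving step-{self.num_steps}.safetensors")

buggy, fixed = ToyLogger(), ToyLogger()
for _ in range(10):          # 10 training batches
    buggy.on_step_end()      # call inside the accumulate block
    buggy.on_step_end()      # redundant second call after the block
    fixed.on_step_end()      # fixed: exactly one call per batch

print(buggy.num_steps)  # 20 — counter runs twice as fast as intended
print(fixed.num_steps)  # 10
```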
```python
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
)
```
The new --resume_from_checkpoint argument is missing a help description. Adding one improves the usability of the CLI by explaining what the parameter expects.
Suggested change:

```python
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
    help="Path to a checkpoint directory to resume training from.",
)
```
```python
save_dir = os.path.join(
    args.output_path, f"checkpoint-{global_step}"
)
```
Accessing args.output_path directly here is unsafe if args is None. Since model_logger is already initialized with the output path, it is safer and more consistent to use model_logger.output_path for constructing the checkpoint directory path.
Suggested change:

```python
save_dir = os.path.join(
    model_logger.output_path, f"checkpoint-{global_step}"
)
```
@Artiprocher Could you review this PR? I've added a checkpoint resume feature so training can be continued from any saved checkpoint.
Summary
Adds support for resuming training from a saved checkpoint via a new `resume_from_checkpoint` configuration option.

Rationale

This enables restoring the model, optimizer, and scheduler states and continuing training from a specific step instead of restarting from scratch.
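The PR computes the starting epoch and in-epoch step from the recovered global step. As a rough sketch of that arithmetic (the `resume_position` helper name and the fixed steps-per-epoch assumption are illustrative, not taken from the PR):

```python
def resume_position(global_step, steps_per_epoch):
    """Map a recovered global step back to (start_epoch, batches_to_skip),
    assuming a fixed number of optimizer steps per epoch."""
    start_epoch = global_step // steps_per_epoch
    skip_batches = global_step % steps_per_epoch
    return start_epoch, skip_batches

# e.g. resuming from checkpoint-2500 with 1000 steps per epoch:
print(resume_position(2500, 1000))  # (2, 500): start at epoch 2, skip 500 batches
```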