feat: add resume_from_checkpoint option to training configuration#1406

Open
hoangvictor wants to merge 2 commits into modelscope:main from hoangvictor:main

Conversation

@hoangvictor

Summary

Adds support for resuming training from a saved checkpoint via a new resume_from_checkpoint configuration option.

Rationale

This enables restoring model, optimizer, and scheduler states and continuing training from a specific step instead of restarting from scratch.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces the ability to resume training from a checkpoint by adding a new CLI argument and updating the training runner logic. The changes include loading the accelerator state, calculating the correct starting epoch and step based on the checkpoint name, and implementing full checkpoint saving. The review feedback identifies several critical issues: a potential UnboundLocalError if arguments are not provided, a redundant logging call that would cause incorrect step counting, unsafe access to the output path, and the need for better exception handling and CLI documentation.
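The epoch/step arithmetic described above can be sketched as a small helper. This is illustrative only, not the PR's actual code; the function name and signature are assumptions.

```python
def resume_position(global_step: int, steps_per_epoch: int) -> tuple[int, int]:
    """Map the optimizer-step count recovered from a checkpoint name
    (e.g. "checkpoint-1000" -> 1000) to the epoch to resume at and the
    number of batches to skip within that epoch."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return global_step // steps_per_epoch, global_step % steps_per_epoch
```

For example, with 100 steps per epoch, a `checkpoint-1000` directory resumes at epoch 10 with no batches to skip.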

Comment on lines +18 to +26
```python
    args=None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

        optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
        resume_from_checkpoint = args.resume_from_checkpoint
```
Contributor


Severity: high

The variable resume_from_checkpoint is only defined if args is not None. If args is None, accessing it later at line 50 will raise an UnboundLocalError. It should be initialized to None at the start of the function to ensure safety.

Suggested change

```diff
     args=None,
 ):
+    resume_from_checkpoint = None
     if args is not None:
         learning_rate = args.learning_rate
         weight_decay = args.weight_decay
         num_workers = args.dataset_num_workers
         save_steps = args.save_steps
         num_epochs = args.num_epochs

         optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
         resume_from_checkpoint = args.resume_from_checkpoint
```

Comment on lines +56 to +59
```python
try:
    global_step = int(os.path.basename(ckpt_path).split("-")[-1])
except:
    global_step = 0
```
Contributor


Severity: high

Avoid using a bare except: block as it can catch unexpected exceptions like KeyboardInterrupt. Additionally, when resuming from a checkpoint, the model_logger.num_steps should be synchronized with the recovered global_step to ensure that subsequent logging and checkpoint filenames (e.g., step-1001.safetensors) are correct.

Suggested change

```diff
 try:
     global_step = int(os.path.basename(ckpt_path).split("-")[-1])
-except:
+except Exception:
     global_step = 0
+model_logger.num_steps = global_step
```
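A stricter alternative to splitting on "-" is to match the expected directory name pattern explicitly, so that any unexpected name falls back to 0 without relying on exception handling at all. This is a hypothetical helper, not code from the PR:

```python
import os
import re

def step_from_checkpoint_path(ckpt_path: str) -> int:
    """Parse the step count from a "checkpoint-<N>" directory name.

    Any name that does not match the expected pattern falls back to 0
    instead of raising, so no try/except is needed."""
    name = os.path.basename(os.path.normpath(ckpt_path))
    match = re.fullmatch(r"checkpoint-(\d+)", name)
    return int(match.group(1)) if match else 0
```

`os.path.normpath` strips a trailing slash so `checkpoint-1000/` parses the same as `checkpoint-1000`.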

Comment thread: diffsynth/diffusion/runner.py (outdated)
Comment on lines +101 to +106
```python
global_step += 1

# existing logging (LoRA weights etc.)
model_logger.on_step_end(
    accelerator, model, save_steps, loss=loss
)
```
Contributor


Severity: high

The call to model_logger.on_step_end at line 104 is redundant because it is already called inside the accelerator.accumulate block at line 99. Calling it twice per batch will cause the logger's internal step counter to increment twice as fast as intended, leading to incorrect checkpointing intervals and filenames.

Suggested change

```diff
 global_step += 1
-
-# existing logging (LoRA weights etc.)
-model_logger.on_step_end(
-    accelerator, model, save_steps, loss=loss
-)
```

Comment on lines +40 to +44
```python
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
)
```
Contributor


Severity: medium

The new --resume_from_checkpoint argument is missing a help description. Adding one improves the usability of the CLI by explaining what the parameter expects.

Suggested change

```diff
 parser.add_argument(
     "--resume_from_checkpoint",
     type=str,
     default=None,
+    help="Path to a checkpoint directory to resume training from."
 )
```

Comment on lines +110 to +112
```python
save_dir = os.path.join(
    args.output_path, f"checkpoint-{global_step}"
)
```
Contributor


Severity: medium

Accessing args.output_path directly here is unsafe if args is None. Since model_logger is already initialized with the output path, it is safer and more consistent to use model_logger.output_path for constructing the checkpoint directory path.

Suggested change

```diff
 save_dir = os.path.join(
-    args.output_path, f"checkpoint-{global_step}"
+    model_logger.output_path, f"checkpoint-{global_step}"
 )
```

@hoangvictor
Author

@Artiprocher Could you review this PR? I’ve added a checkpoint resume feature so training can be continued from any saved checkpoint.
