From cff966be88755b84875e969f8691583da7433393 Mon Sep 17 00:00:00 2001
From: Tien Nguyen
Date: Thu, 23 Apr 2026 11:46:22 +0700
Subject: [PATCH 1/2] feat: add resume_from_checkpoint option to training configuration

---
 diffsynth/diffusion/parsers.py |  5 ++
 diffsynth/diffusion/runner.py  | 86 +++++++++++++++++++++++++++++++---
 2 files changed, 84 insertions(+), 7 deletions(-)

diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py
index b8c6c6afd..7c4fb8d87 100644
--- a/diffsynth/diffusion/parsers.py
+++ b/diffsynth/diffusion/parsers.py
@@ -37,6 +37,11 @@ def add_training_config(parser: argparse.ArgumentParser):
     parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
     parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+    )
     return parser
 
 def add_output_config(parser: argparse.ArgumentParser):
diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index 43a8d3766..6c89443e2 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -15,7 +15,7 @@ def launch_training_task(
     num_workers: int = 1,
     save_steps: int = None,
     num_epochs: int = 1,
-    args = None,
+    args=None,
 ):
     if args is not None:
         learning_rate = args.learning_rate
@@ -23,27 +23,99 @@ def launch_training_task(
         num_workers = args.dataset_num_workers
         save_steps = args.save_steps
         num_epochs = args.num_epochs
-
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
+        resume_from_checkpoint = args.resume_from_checkpoint
+
+    optimizer = torch.optim.AdamW(
+        model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay
+    )
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        shuffle=True,
+        collate_fn=lambda x: x[0],
+        num_workers=num_workers,
+    )
+
     model.to(device=accelerator.device)
-    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+    model, optimizer, dataloader, scheduler = accelerator.prepare(
+        model, optimizer, dataloader, scheduler
+    )
+
     initialize_deepspeed_gradient_checkpointing(accelerator)
-    for epoch_id in range(num_epochs):
-        for data in tqdm(dataloader):
+
+    # Resume from checkpoints if specified
+    global_step = 0
+
+    if resume_from_checkpoint:
+        ckpt_path = args.resume_from_checkpoint
+        accelerator.print(f"Resuming from {ckpt_path}")
+        accelerator.load_state(ckpt_path)
+
+        # recover step number from folder name
+        try:
+            global_step = int(os.path.basename(ckpt_path).split("-")[-1])
+        except:
+            global_step = 0
+
+    # compute steps per epoch after dataloader is built
+    steps_per_epoch = len(dataloader)
+
+    start_epoch = 0
+    resume_step = 0
+    if global_step > 0:
+        start_epoch = global_step // steps_per_epoch
+        resume_step = global_step % steps_per_epoch
+
+        accelerator.print(
+            f"Resuming at global_step={global_step}, "
+            f"start_epoch={start_epoch}, resume_step={resume_step}"
+        )
+
+    # Training loop
+    for epoch_id in range(start_epoch, num_epochs):
+        progress_bar = tqdm(
+            enumerate(dataloader),
+            total=len(dataloader),
+            disable=not accelerator.is_local_main_process
+        )
+
+        for step, data in progress_bar:
+
+            # skip already-trained steps
+            if epoch_id == start_epoch and step < resume_step:
+                continue
+
             with accelerator.accumulate(model):
                 if dataset.load_from_cache:
                     loss = model({}, inputs=data)
                 else:
                     loss = model(data)
+
                 accelerator.backward(loss)
                 optimizer.step()
                 scheduler.step()
                 optimizer.zero_grad()
                 model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
+
+            global_step += 1
+
+            # existing logging (LoRA weights etc.)
+            model_logger.on_step_end(
+                accelerator, model, save_steps, loss=loss
+            )
+
+            # Full checkpoint save
+            if save_steps is not None and global_step % save_steps == 0:
+                save_dir = os.path.join(
+                    args.output_path, f"checkpoint-{global_step}"
+                )
+                accelerator.save_state(save_dir)
+                accelerator.print(f"Saved checkpoint to {save_dir}")
+
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
+
     model_logger.on_training_end(accelerator, model, save_steps)

From 7f6d59728840961521dd3eb46695d8578be1e696 Mon Sep 17 00:00:00 2001
From: Tien Nguyen
Date: Thu, 23 Apr 2026 13:26:41 +0700
Subject: [PATCH 2/2] chore: refactor

---
 diffsynth/diffusion/parsers.py |  1 +
 diffsynth/diffusion/runner.py  | 10 ++++------
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py
index 7c4fb8d87..6e1b4d3d8 100644
--- a/diffsynth/diffusion/parsers.py
+++ b/diffsynth/diffusion/parsers.py
@@ -41,6 +41,7 @@ def add_training_config(parser: argparse.ArgumentParser):
         "--resume_from_checkpoint",
         type=str,
         default=None,
+        help="Path to a checkpoint directory to resume training from."
     )
     return parser
 
diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index 6c89443e2..1652ee783 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -17,6 +17,7 @@ def launch_training_task(
     num_epochs: int = 1,
     args=None,
 ):
+    resume_from_checkpoint = None
     if args is not None:
         learning_rate = args.learning_rate
         weight_decay = args.weight_decay
@@ -58,6 +59,8 @@ def launch_training_task(
         except:
             global_step = 0
 
+    model_logger.num_steps = global_step
+
     # compute steps per epoch after dataloader is built
     steps_per_epoch = len(dataloader)
 
@@ -100,15 +103,10 @@ def launch_training_task(
 
             global_step += 1
 
-            # existing logging (LoRA weights etc.)
-            model_logger.on_step_end(
-                accelerator, model, save_steps, loss=loss
-            )
-
             # Full checkpoint save
             if save_steps is not None and global_step % save_steps == 0:
                 save_dir = os.path.join(
-                    args.output_path, f"checkpoint-{global_step}"
+                    model_logger.output_path, f"checkpoint-{global_step}"
                 )
                 accelerator.save_state(save_dir)
                 accelerator.print(f"Saved checkpoint to {save_dir}")
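
A minimal standalone sketch of the resume arithmetic these patches introduce in runner.py; the checkpoint path, steps_per_epoch, and num_epochs below are hypothetical stand-ins, and the real accelerator.load_state/save_state calls are omitted:

    import os

    ckpt_path = "outputs/checkpoint-250"  # hypothetical "checkpoint-<global_step>" directory
    steps_per_epoch = 100                 # stands in for len(dataloader)
    num_epochs = 4                        # hypothetical

    # recover the global step from the folder name, as runner.py does
    global_step = int(os.path.basename(ckpt_path).split("-")[-1])

    start_epoch = global_step // steps_per_epoch  # -> 2
    resume_step = global_step % steps_per_epoch   # -> 50

    trained = 0
    for epoch_id in range(start_epoch, num_epochs):
        for step in range(steps_per_epoch):
            # skip the batches already covered in the partially finished epoch
            if epoch_id == start_epoch and step < resume_step:
                continue
            trained += 1  # the real loop runs forward/backward here

    # 250 steps done before the restart + 150 remaining = 400 total
    print(global_step, start_epoch, resume_step, trained)  # 250 2 50 150

One caveat of this skip-based fast-forward: the DataLoader is built with shuffle=True and its sampler state is not saved, so a restart re-draws the shuffle order and the skipped batches are not guaranteed to match the ones already seen. Accelerate also provides a skip_first_batches helper for the same fast-forwarding step.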