6 changes: 6 additions & 0 deletions diffsynth/diffusion/parsers.py
@@ -37,6 +37,12 @@ def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="Path to a checkpoint directory to resume training from."
)
Comment on lines +40 to +45
medium

The new --resume_from_checkpoint argument is missing a help description. Adding one improves the usability of the CLI by explaining what the parameter expects.

Suggested change
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to a checkpoint directory to resume training from."
    )
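
As a standalone illustration (separate from this PR's parser, with a placeholder program name), argparse surfaces the help string directly in the generated --help output:

import argparse

# Minimal sketch: a parser containing only the new flag.
parser = argparse.ArgumentParser(prog="train.py")
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
    help="Path to a checkpoint directory to resume training from.",
)
parser.print_help()
# Among the printed lines:
#   --resume_from_checkpoint RESUME_FROM_CHECKPOINT
#                         Path to a checkpoint directory to resume training from.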

    return parser

def add_output_config(parser: argparse.ArgumentParser):
84 changes: 77 additions & 7 deletions diffsynth/diffusion/runner.py
@@ -15,35 +15,105 @@ def launch_training_task(
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
    args=None,
):
    resume_from_checkpoint = None
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
        resume_from_checkpoint = args.resume_from_checkpoint
Comment on lines +18 to +27
high

The variable resume_from_checkpoint is only defined if args is not None. If args is None, accessing it later at line 50 will raise an UnboundLocalError. It should be initialized to None at the start of the function to ensure safety.

Suggested change
    args=None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
        resume_from_checkpoint = args.resume_from_checkpoint
    args=None,
):
    resume_from_checkpoint = None
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        resume_from_checkpoint = args.resume_from_checkpoint
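
For reference, a reduced self-contained sketch of the failure mode the reviewer describes (hypothetical function, not the PR's code):

# The name is bound only inside the conditional, so skipping the
# branch leaves it undefined at the later read.
def launch(args=None):
    if args is not None:
        resume_from_checkpoint = args
    return resume_from_checkpoint

launch()  # raises UnboundLocalError when args is None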


    optimizer = torch.optim.AdamW(
        model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )

    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    initialize_deepspeed_gradient_checkpointing(accelerator)
    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):

    # Resume from checkpoints if specified
    global_step = 0

    if resume_from_checkpoint:
        ckpt_path = args.resume_from_checkpoint
        accelerator.print(f"Resuming from {ckpt_path}")
        accelerator.load_state(ckpt_path)

        # recover step number from folder name
        try:
            global_step = int(os.path.basename(ckpt_path).split("-")[-1])
        except:
            global_step = 0
Comment on lines +57 to +60

high

Avoid using a bare except: block as it can catch unexpected exceptions like KeyboardInterrupt. Additionally, when resuming from a checkpoint, the model_logger.num_steps should be synchronized with the recovered global_step to ensure that subsequent logging and checkpoint filenames (e.g., step-1001.safetensors) are correct.

Suggested change
try:
global_step = int(os.path.basename(ckpt_path).split("-")[-1])
except:
global_step = 0
try:
global_step = int(os.path.basename(ckpt_path).split("-")[-1])
except Exception:
global_step = 0
model_logger.num_steps = global_step
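
If stricter parsing were desired, one option (not part of this PR; helper name is illustrative) is to accept only a trailing integer and fall back to step 0 explicitly:

import os
import re

def parse_step_from_checkpoint(ckpt_path: str) -> int:
    # Hypothetical helper: "checkpoint-1000" -> 1000; anything else -> 0.
    match = re.search(r"-(\d+)$", os.path.basename(ckpt_path.rstrip("/")))
    return int(match.group(1)) if match else 0

assert parse_step_from_checkpoint("out/checkpoint-1000") == 1000
assert parse_step_from_checkpoint("out/final") == 0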


    model_logger.num_steps = global_step

    # compute steps per epoch after dataloader is built
    steps_per_epoch = len(dataloader)

    start_epoch = 0
    resume_step = 0
    if global_step > 0:
        start_epoch = global_step // steps_per_epoch
        resume_step = global_step % steps_per_epoch

        accelerator.print(
            f"Resuming at global_step={global_step}, "
            f"start_epoch={start_epoch}, resume_step={resume_step}"
        )

    # Training loop
    for epoch_id in range(start_epoch, num_epochs):
        progress_bar = tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            disable=not accelerator.is_local_main_process
        )

        for step, data in progress_bar:

            # skip already-trained steps
            if epoch_id == start_epoch and step < resume_step:
                continue

            with accelerator.accumulate(model):
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)

                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)

            global_step += 1

            # Full checkpoint save
            if save_steps is not None and global_step % save_steps == 0:
                save_dir = os.path.join(
                    model_logger.output_path, f"checkpoint-{global_step}"
                )
Comment on lines +108 to +110

medium

Accessing args.output_path directly here is unsafe if args is None. Since model_logger is already initialized with the output path, it is safer and more consistent to use model_logger.output_path for constructing the checkpoint directory path.

Suggested change
                save_dir = os.path.join(
                    args.output_path, f"checkpoint-{global_step}"
                )
                save_dir = os.path.join(
                    model_logger.output_path, f"checkpoint-{global_step}"
                )
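
A small standalone sketch of the reviewer's point (helper name is illustrative): deriving the directory from the logger's path keeps the save independent of args, which may be None when the runner is called directly:

import os

def checkpoint_dir(logger_output_path: str, global_step: int) -> str:
    # The logger's output path is always set, unlike args.
    return os.path.join(logger_output_path, f"checkpoint-{global_step}")

print(checkpoint_dir("models/train", 1000))  # models/train/checkpoint-1000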

                accelerator.save_state(save_dir)
                accelerator.print(f"Saved checkpoint to {save_dir}")

        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)

    model_logger.on_training_end(accelerator, model, save_steps)

