Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion diffsynth/diffusion/runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, torch, importlib
import os, json, torch, importlib
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
Expand All @@ -16,6 +16,25 @@ def get_optimizer_class(customized_optimizer=None):
return getattr(module, class_name)


def save_training_args(accelerator: Accelerator, model_logger: ModelLogger, args):
"""Dump the parsed training arguments to ``output_path/training_args.json`` for reproducibility.

Saving is best-effort: a failure here (a non-serializable value, a permission error, a full disk, ...)
should never interrupt training, so any exception is caught and reported as a warning. ``default=str``
keeps the dump robust if a future argument holds a non-JSON-serializable value.
"""
if args is None or not accelerator.is_main_process:
return
try:
os.makedirs(model_logger.output_path, exist_ok=True)
save_path = os.path.join(model_logger.output_path, "training_args.json")
with open(save_path, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4, ensure_ascii=False, default=str)
print(f"Training arguments saved to `{save_path}`.")
except Exception as e:
print(f"Warning: failed to save training arguments: {e}")


def launch_training_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
Expand Down Expand Up @@ -44,6 +63,8 @@ def launch_training_task(
cpu_offload_split_threshold = args.cpu_offload_split_threshold
customized_optimizer = args.customized_optimizer

save_training_args(accelerator, model_logger, args)

optimizer_class = get_optimizer_class(customized_optimizer)
optimizer = optimizer_class(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
Expand Down