diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index 38e81092..342d9472 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -1,4 +1,4 @@
-import os, torch, importlib
+import os, json, torch, importlib
 from tqdm import tqdm
 from accelerate import Accelerator
 from .training_module import DiffusionTrainingModule
@@ -16,6 +16,35 @@ def get_optimizer_class(customized_optimizer=None):
     return getattr(module, class_name)
 
 
+def save_training_args(accelerator: Accelerator, model_logger: ModelLogger, args):
+    """Dump the parsed training arguments to ``output_path/training_args.json`` for reproducibility.
+
+    Only the main process writes the file; other processes, and calls with ``args=None``, return immediately.
+
+    Saving is best-effort: a failure here (a non-serializable value, a permission error, a full disk, ...)
+    should never interrupt training, so any exception is caught and reported as a warning. ``default=str``
+    keeps the dump robust if a future argument holds a non-JSON-serializable value.
+
+    Args:
+        accelerator: Its ``is_main_process`` flag decides whether this process writes the file.
+        model_logger: Its ``output_path`` is the directory that receives ``training_args.json``.
+        args: Parsed training arguments; ``vars(args)`` is the mapping that gets serialized.
+    """
+    if args is None or not accelerator.is_main_process:
+        return
+    try:
+        # Serialize before touching the file system: opening with "w" truncates the target, so a
+        # serialization error raised after the open would leave an empty or partial file behind.
+        payload = json.dumps(vars(args), indent=4, ensure_ascii=False, default=str)
+        os.makedirs(model_logger.output_path, exist_ok=True)
+        save_path = os.path.join(model_logger.output_path, "training_args.json")
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(payload)
+        print(f"Training arguments saved to `{save_path}`.")
+    except Exception as e:
+        print(f"Warning: failed to save training arguments: {e}")
+
+
 def launch_training_task(
     accelerator: Accelerator,
     dataset: torch.utils.data.Dataset,
@@ -44,6 +73,8 @@ def launch_training_task(
     cpu_offload_split_threshold = args.cpu_offload_split_threshold
     customized_optimizer = args.customized_optimizer
 
+    save_training_args(accelerator, model_logger, args)
+
     optimizer_class = get_optimizer_class(customized_optimizer)
     optimizer = optimizer_class(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)