def save_training_args(accelerator: "Accelerator", model_logger: "ModelLogger", args):
    """Write the parsed training arguments to ``<output_path>/training_args.json``.

    The dump lets a run be reproduced without recovering the original launch command.
    ``launch_training_task`` calls this once, after unpacking ``args`` and before the
    optimizer is created.

    Saving is best-effort: any ``Exception`` (unserializable data, permission error,
    full disk, ...) is reported as a printed warning and never interrupts training.

    Args:
        accelerator: Only ``is_main_process`` is read. Non-main processes return
            without writing, so a multi-process launch has a single writer.
        model_logger: Only ``output_path`` is read; the directory is created if missing.
        args: Object whose ``vars()`` mapping is dumped (presumably an
            ``argparse.Namespace`` -- confirm against callers). ``None`` is a no-op.

    Returns:
        None.
    """
    if args is None or not accelerator.is_main_process:
        return
    try:
        # Serialize completely in memory before touching the disk. ``default=str`` only
        # rescues unsupported *values*; an unsupported dict key (e.g. a tuple) or a
        # circular reference still raises. Streaming with ``json.dump`` into a file
        # already opened with "w" would then leave a truncated file and destroy any
        # earlier good dump.
        payload = json.dumps(vars(args), indent=4, ensure_ascii=False, default=str)
        os.makedirs(model_logger.output_path, exist_ok=True)
        save_path = os.path.join(model_logger.output_path, "training_args.json")
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(payload)
        print(f"Training arguments saved to `{save_path}`.")
    except Exception as e:
        # Deliberately broad: persisting the arguments must never abort a training run.
        print(f"Warning: failed to save training arguments: {e}")