From 01dfe52b15a572e4e44e2aa22a1f15b2c816ff8b Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:21:47 +0000 Subject: [PATCH] Add gradient clipping options to optimizer Introduces options for clipping gradients by global norm or by value, configurable via `config.opt_enable_grad_global_norm_clipping` and `config.opt_enable_grad_clipping`, as well as `config.max_grad_norm` and `config.max_grad_value`. Co-authored-by: martinarroyo --- src/maxdiffusion/configs/base14.yml | 3 +++ src/maxdiffusion/configs/base21.yml | 3 +++ src/maxdiffusion/configs/base_2_base.yml | 3 +++ src/maxdiffusion/configs/base_flux_dev.yml | 3 +++ src/maxdiffusion/configs/base_flux_dev_multi_res.yml | 3 +++ src/maxdiffusion/configs/base_flux_schnell.yml | 3 +++ src/maxdiffusion/configs/base_wan_14b.yml | 3 +++ src/maxdiffusion/configs/base_wan_1_3b.yml | 3 +++ src/maxdiffusion/configs/base_wan_27b.yml | 3 +++ src/maxdiffusion/configs/base_wan_i2v_14b.yml | 3 +++ src/maxdiffusion/configs/base_wan_i2v_27b.yml | 3 +++ src/maxdiffusion/configs/base_xl.yml | 3 +++ src/maxdiffusion/configs/base_xl_lightning.yml | 3 +++ src/maxdiffusion/max_utils.py | 10 +++++++++- 14 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index ca2579d92..832efb343 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 65e7d19e0..84807830f 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 16948296a..95435d041 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 7a508095f..996ae177f 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -245,6 +245,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index 1aba7431f..af3a0bde2 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -232,6 +232,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 9ae399713..8454d4809 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -240,6 +240,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 91a3e092a..d478673aa 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -301,6 +301,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index ffd2864a8..db7ccb266 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -257,6 +257,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 022b18c91..ba5486c10 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -268,6 +268,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 2a5b0338c..3d00364be 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -263,6 +263,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0bd6a27f2..ec308af43 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -264,6 +264,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 0 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 3dbb1578e..34c74fc6b 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -205,6 +205,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index e487559a7..b140a6c1a 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -166,6 +166,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 04b3869fe..cca8e0c99 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -483,13 +483,21 @@ def create_learning_rate_schedule(learning_rate, learning_rate_schedule_steps, w def create_optimizer(config, learning_rate_scheduler): - return optax.adamw( + opt = optax.adamw( learning_rate=learning_rate_scheduler, b1=config.adam_b1, b2=config.adam_b2, eps=config.adam_eps, weight_decay=config.adam_weight_decay, ) + if config.opt_enable_grad_global_norm_clipping: + opt = optax.chain( + optax.clip_by_global_norm(config.max_grad_norm), opt + ) + + if config.opt_enable_grad_clipping: + opt = optax.chain(optax.clip(config.max_grad_value), opt) + return opt def get_precision(config):