From 5ab4de7a250370c0faea57ec6b9272acc017cb64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Tue, 16 Jun 2026 18:23:15 +0200 Subject: [PATCH 1/6] feat: add logit-normal schedule --- include/stable-diffusion.h | 1 + src/runtime/denoiser.hpp | 187 +++++++++++++++++++++++++++++++++++++ src/stable-diffusion.cpp | 3 + 3 files changed, 191 insertions(+) diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 730794e6b..7e21d6624 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -70,6 +70,7 @@ enum scheduler_t { LCM_SCHEDULER, BONG_TANGENT_SCHEDULER, LTX2_SCHEDULER, + LOGIT_NORMAL_SCHEDULER, SCHEDULER_COUNT }; diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index fed5911bc..6b46b29f2 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -559,6 +559,168 @@ struct LTX2Scheduler : SigmaScheduler { } }; +/* + * Logit-Normal Scheduler + * Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py + */ +struct LogitNormalScheduler : SigmaScheduler { + float mean = 0.0f; + float std = 1.0f; + float logsnr_min = -15.0f; + float logsnr_max = 18.0f; + + LogitNormalScheduler(float mean, float std, float logsnr_min, float logsnr_max) + : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {} + + // https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ + // translated to c++ with Qwen + double ndtri(double p) { + if (p <= 0.0) { + return -std::numeric_limits::infinity(); + } else if (p >= 1.0) { + return std::numeric_limits::infinity(); + } + + static const double p_low = 0.02425; + static const double p_high = 1.0 - p_low; + + static const double c[6] = { + -7.784894002430293e-03, + -3.223964580411365e-01, + -2.400758277161838e+00, + -2.549732539343734e+00, + 4.374664141464968e+00, + 2.938163982698783e+00}; + + static const double d[5] = { + 7.784695709041462e-03, + 3.224671290700398e-01, + 2.445134137142996e+00, + 3.754408661907416e+00, + 1.0 // Implicit +1 in denominator + }; + + // Coefficients for the central region + static const double a[6] = { + -3.969683028665376e+01, + 2.209460984245205e+02, + -2.759285104469687e+02, + 1.383577518672690e+02, + -3.066479806614716e+01, + 2.506628277459239e+00}; + + static const double b[5] = { + -5.447609879822406e+01, + 1.615858368580409e+02, + -1.556989798598866e+02, + 6.680131188771972e+01, + -1.328068155288572e+01}; + + double x = 0.0; + + if (p < p_low) { + // Lower region + double q = std::sqrt(-2.0 * std::log(p)); + + // Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5] + // Using Horner's method for polynomial evaluation + double numerator = c[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * q + c[i]; + } + + // Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1 + double denominator = d[0]; + for (int i = 1; i < 4; ++i) { + denominator = denominator * q + d[i]; + } + denominator = denominator * q + 1.0; // Add the final +1 + + x = numerator / denominator; + + } else if (p > p_high) { + // Upper region + double q = std::sqrt(-2.0 * std::log(1.0 - p)); + + // Same polynomial structure as lower region, but result is negated + double numerator = c[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * q + c[i]; + } + + double denominator = d[0]; + for (int i = 1; i < 4; ++i) { + denominator = denominator * q + d[i]; + } + denominator = denominator * q + 1.0; + + x = -(numerator / denominator); + + } else { + // Central region + double q = p - 0.5; + double r = q * q; + + // Numerator: a[0]*r^5 + a[1]*r^4 + ... + a[5] + // Then multiply by q + double numerator = a[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * r + a[i]; + } + numerator *= q; + + // Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1 + double denominator = b[0]; + for (int i = 1; i < 5; ++i) { + denominator = denominator * r + b[i]; + } + denominator = denominator * r + 1.0; // Add the final +1 + + x = numerator / denominator; + } + + return x; + } + + static float expit(float x) { + return 1.0f / (1.0f + std::exp(-x)); + } + + std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override { + std::vector sigmas; + sigmas.reserve(n + 1); + + // Precompute bounds based on logsnr + float t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + float t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + + for (uint32_t i = 0; i <= n; ++i) { + float t = static_cast(i) / static_cast(n); + + float z = ndtri(t); + + float y = mean + std * z; + + float t_ = expit(y); + + float sigma = 1.0f - t_; + + if (sigma < t_min) + sigma = t_min; + if (sigma > t_max) + sigma = t_max; + + sigmas.push_back(sigma); + } + + if (!sigmas.empty()) { + sigmas.back() = 0.0f; + } + + return sigmas; + } +}; + struct Denoiser { virtual float sigma_min() = 0; virtual float sigma_max() = 0; @@ -623,6 +785,31 @@ struct Denoiser { LOG_INFO("get_sigmas with LTX2 scheduler"); scheduler = std::make_shared(image_seq_len, extra_sample_args); break; + case LOGIT_NORMAL_SCHEDULER: { + const int known_seq_len = (512 * 512) / (16 * 16); + // todo: mu and std from extra_sample_args + float mu = 0.; + float std = 1.75; + + if(extra_sample_args) { + for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) { + if (key == "mu") { + if (!parse_strict_float(value, mu)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "std") { + if (!parse_strict_float(value, std)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } + } + } + + float mean = mu + 0.5 * std::log(static_cast(image_seq_len) / static_cast(known_seq_len)); + LOG_INFO("get_sigmas with Logit-Normal scheduler with mean=%.4f (mu=%.4f) and std=%.4f", mean, mu, std); + scheduler = std::make_shared(mean, std, -15.0f, 18.0f); + break; + } default: LOG_INFO("get_sigmas with discrete scheduler (default)"); scheduler = std::make_shared(); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 93180bfd3..c33f454e9 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = { "lcm", "bong_tangent", "ltx2", + "logit_normal", }; const char* sd_scheduler_name(enum scheduler_t scheduler) { @@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return SIMPLE_SCHEDULER; } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) { return LTX2_SCHEDULER; + } else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) { + return LOGIT_NORMAL_SCHEDULER; } return DISCRETE_SCHEDULER; } From b3e0ba85a27dfb3f6d149b2b5a8dc584651ab6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 19 Jun 2026 18:05:09 +0200 Subject: [PATCH 2/6] refactor --- src/runtime/denoiser.hpp | 172 ++++++++++++++++++++------------------- 1 file changed, 88 insertions(+), 84 deletions(-) diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index 6b46b29f2..a8cf8ff26 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -565,15 +565,65 @@ struct LTX2Scheduler : SigmaScheduler { */ struct LogitNormalScheduler : SigmaScheduler { float mean = 0.0f; - float std = 1.0f; + float std = 1.75f; float logsnr_min = -15.0f; float logsnr_max = 18.0f; - LogitNormalScheduler(float mean, float std, float logsnr_min, float logsnr_max) - : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {} + bool resolution_aware = true; + + float t_min, t_max; + + void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) { + const int known_seq_len = (512 * 512) / (16 * 16); + if (extra_sample_args) { + for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) { + if (key == "mu") { + if (!parse_strict_float(value, mean)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "std") { + if (!parse_strict_float(value, std)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } + if (key == "logsnr_min") { + if (!parse_strict_float(value, logsnr_min)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "logsnr_max") { + if (!parse_strict_float(value, logsnr_max)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "resolution_aware") { + if (!parse_strict_bool(value, resolution_aware)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } + } + } + if (image_seq_len > 0 && resolution_aware) { + mean += 0.5 * std::log(static_cast(image_seq_len) / static_cast(known_seq_len)); + } + } + + LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f) + : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) { + t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + } + + LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) { + mean = 0.0f; + std = 1.75f; + logsnr_min = -15.0f; + logsnr_max = 18.0f; + + parse_extra_sample_args(image_seq_len, extra_sample_args); + t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + } // https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ - // translated to c++ with Qwen double ndtri(double p) { if (p <= 0.0) { return -std::numeric_limits::infinity(); @@ -584,37 +634,33 @@ struct LogitNormalScheduler : SigmaScheduler { static const double p_low = 0.02425; static const double p_high = 1.0 - p_low; - static const double c[6] = { - -7.784894002430293e-03, - -3.223964580411365e-01, - -2.400758277161838e+00, - -2.549732539343734e+00, - 4.374664141464968e+00, - 2.938163982698783e+00}; - - static const double d[5] = { - 7.784695709041462e-03, - 3.224671290700398e-01, - 2.445134137142996e+00, - 3.754408661907416e+00, - 1.0 // Implicit +1 in denominator - }; + static const double c[6] = {-7.784894002430293e-03, + -3.223964580411365e-01, + -2.400758277161838e+00, + -2.549732539343734e+00, + 4.374664141464968e+00, + 2.938163982698783e+00}; + + static const double d[5] = {7.784695709041462e-03, + 3.224671290700398e-01, + 2.445134137142996e+00, + 3.754408661907416e+00, + 1.0}; // Coefficients for the central region - static const double a[6] = { - -3.969683028665376e+01, - 2.209460984245205e+02, - -2.759285104469687e+02, - 1.383577518672690e+02, - -3.066479806614716e+01, - 2.506628277459239e+00}; - - static const double b[5] = { - -5.447609879822406e+01, - 1.615858368580409e+02, - -1.556989798598866e+02, - 6.680131188771972e+01, - -1.328068155288572e+01}; + static const double a[6] = {-3.969683028665376e+01, + 2.209460984245205e+02, + -2.759285104469687e+02, + 1.383577518672690e+02, + -3.066479806614716e+01, + 2.506628277459239e+00}; + + static const double b[6] = {-5.447609879822406e+01, + 1.615858368580409e+02, + -1.556989798598866e+02, + 6.680131188771972e+01, + -1.328068155288572e+01, + 1.0}; double x = 0.0; @@ -623,7 +669,6 @@ struct LogitNormalScheduler : SigmaScheduler { double q = std::sqrt(-2.0 * std::log(p)); // Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5] - // Using Horner's method for polynomial evaluation double numerator = c[0]; for (int i = 1; i < 6; ++i) { numerator = numerator * q + c[i]; @@ -631,38 +676,32 @@ struct LogitNormalScheduler : SigmaScheduler { // Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1 double denominator = d[0]; - for (int i = 1; i < 4; ++i) { + for (int i = 1; i < 5; ++i) { denominator = denominator * q + d[i]; } - denominator = denominator * q + 1.0; // Add the final +1 x = numerator / denominator; - } else if (p > p_high) { // Upper region double q = std::sqrt(-2.0 * std::log(1.0 - p)); - // Same polynomial structure as lower region, but result is negated double numerator = c[0]; for (int i = 1; i < 6; ++i) { numerator = numerator * q + c[i]; } double denominator = d[0]; - for (int i = 1; i < 4; ++i) { + for (int i = 1; i < 5; ++i) { denominator = denominator * q + d[i]; } - denominator = denominator * q + 1.0; x = -(numerator / denominator); - } else { // Central region double q = p - 0.5; double r = q * q; - // Numerator: a[0]*r^5 + a[1]*r^4 + ... + a[5] - // Then multiply by q + // Numerator: (a[0]*r^5 + a[1]*r^4 + ... + a[5])*q double numerator = a[0]; for (int i = 1; i < 6; ++i) { numerator = numerator * r + a[i]; @@ -671,29 +710,20 @@ struct LogitNormalScheduler : SigmaScheduler { // Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1 double denominator = b[0]; - for (int i = 1; i < 5; ++i) { + for (int i = 1; i < 6; ++i) { denominator = denominator * r + b[i]; } - denominator = denominator * r + 1.0; // Add the final +1 x = numerator / denominator; } - return x; } - static float expit(float x) { - return 1.0f / (1.0f + std::exp(-x)); - } - std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override { std::vector sigmas; + LOG_WARN("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f (t_max=%.4f), logsnr_max=%.4f (t_min=%.4f)", mean, std, logsnr_min, t_max, logsnr_max, t_min); sigmas.reserve(n + 1); - // Precompute bounds based on logsnr - float t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); - float t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); - for (uint32_t i = 0; i <= n; ++i) { float t = static_cast(i) / static_cast(n); @@ -701,9 +731,7 @@ struct LogitNormalScheduler : SigmaScheduler { float y = mean + std * z; - float t_ = expit(y); - - float sigma = 1.0f - t_; + float sigma = 1.0f / (1.0f + exp(y)); // == 1 - sigmoid(y) if (sigma < t_min) sigma = t_min; @@ -712,11 +740,7 @@ struct LogitNormalScheduler : SigmaScheduler { sigmas.push_back(sigma); } - - if (!sigmas.empty()) { - sigmas.back() = 0.0f; - } - + sigmas[n] = 0.0f; return sigmas; } }; @@ -786,28 +810,8 @@ struct Denoiser { scheduler = std::make_shared(image_seq_len, extra_sample_args); break; case LOGIT_NORMAL_SCHEDULER: { - const int known_seq_len = (512 * 512) / (16 * 16); - // todo: mu and std from extra_sample_args - float mu = 0.; - float std = 1.75; - - if(extra_sample_args) { - for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) { - if (key == "mu") { - if (!parse_strict_float(value, mu)) { - LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); - } - } else if (key == "std") { - if (!parse_strict_float(value, std)) { - LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); - } - } - } - } - - float mean = mu + 0.5 * std::log(static_cast(image_seq_len) / static_cast(known_seq_len)); - LOG_INFO("get_sigmas with Logit-Normal scheduler with mean=%.4f (mu=%.4f) and std=%.4f", mean, mu, std); - scheduler = std::make_shared(mean, std, -15.0f, 18.0f); + LOG_INFO("get_sigmas with Logit-Normal scheduler"); + scheduler = std::make_shared(image_seq_len, extra_sample_args); break; } default: From ca0c19ee9b59604969b8d513a8b7c9df71ed419b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 19 Jun 2026 18:19:58 +0200 Subject: [PATCH 3/6] update help message --- examples/common/common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 8d85db554..9b5eb2aac 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() { &hires_upscaler}, {"", "--extra-sample-args", - "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;", + "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit-normal supports mu, std, logsnr_min, logsnr_max, resolution_aware", (int)',', &extra_sample_args}, {"", @@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() { on_high_noise_sample_method_arg}, {"", "--scheduler", - "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific", + "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific", on_scheduler_arg}, {"", "--sigmas", From f2d423e2b2a3b04f76cc8002a1ded4c43fbd46e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 20 Jun 2026 11:06:43 +0200 Subject: [PATCH 4/6] fix typo --- examples/common/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 9b5eb2aac..05dc5196e 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() { &hires_upscaler}, {"", "--extra-sample-args", - "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit-normal supports mu, std, logsnr_min, logsnr_max, resolution_aware", + "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware", (int)',', &extra_sample_args}, {"", From dd6b256494955667c2ba146e9cdc35f93272813b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 20 Jun 2026 16:22:16 +0200 Subject: [PATCH 5/6] reverse --- src/runtime/denoiser.hpp | 43 +++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index a8cf8ff26..0d18f0aa4 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -571,7 +571,7 @@ struct LogitNormalScheduler : SigmaScheduler { bool resolution_aware = true; - float t_min, t_max; + float one_minus_t_min, one_minus_t_max; void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) { const int known_seq_len = (512 * 512) / (16 * 16); @@ -606,10 +606,17 @@ struct LogitNormalScheduler : SigmaScheduler { } } + float sigmoid(float x) { + return 1.0f / (1.0f + std::exp(-x)); + } + LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f) : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) { - t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); - t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + // t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + one_minus_t_min = sigmoid(0.5f * logsnr_max); + // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + one_minus_t_max = sigmoid(0.5f * logsnr_min); + } LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) { @@ -619,8 +626,10 @@ struct LogitNormalScheduler : SigmaScheduler { logsnr_max = 18.0f; parse_extra_sample_args(image_seq_len, extra_sample_args); - t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); - t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + // t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + one_minus_t_min = sigmoid(0.5f * logsnr_max); + // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + one_minus_t_max = sigmoid(0.5f * logsnr_min); } // https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ @@ -719,24 +728,30 @@ struct LogitNormalScheduler : SigmaScheduler { return x; } - std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override { + std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t t_to_sigma) override { std::vector sigmas; - LOG_WARN("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f (t_max=%.4f), logsnr_max=%.4f (t_min=%.4f)", mean, std, logsnr_min, t_max, logsnr_max, t_min); + LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max); sigmas.reserve(n + 1); - for (uint32_t i = 0; i <= n; ++i) { float t = static_cast(i) / static_cast(n); - float z = ndtri(t); + // ndtri(1-t) == -ndtri(t) + float z = -ndtri(t); float y = mean + std * z; - float sigma = 1.0f / (1.0f + exp(y)); // == 1 - sigmoid(y) + float timestep = sigmoid(y); + + if (timestep > one_minus_t_min) + timestep = one_minus_t_min; + if (timestep < one_minus_t_max) + timestep = one_minus_t_max; - if (sigma < t_min) - sigma = t_min; - if (sigma > t_max) - sigma = t_max; + // which one is corrrect ? + float sigma = timestep; + // float sigma = t_to_sigma(timestep * TIMESTEPS); + // float sigma = t_to_sigma(timestep * (TIMESTEPS - 1)); + // float sigma = t_to_sigma(timestep * TIMESTEPS - 1); sigmas.push_back(sigma); } From f96691d1a6284794b7acb889035e2461392ebd5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 22 Jun 2026 18:02:58 +0200 Subject: [PATCH 6/6] assume it's sigma-space --- src/runtime/denoiser.hpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index 0d18f0aa4..28b29ef27 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -728,7 +728,7 @@ struct LogitNormalScheduler : SigmaScheduler { return x; } - std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t t_to_sigma) override { + std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override { std::vector sigmas; LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max); sigmas.reserve(n + 1); @@ -747,11 +747,7 @@ struct LogitNormalScheduler : SigmaScheduler { if (timestep < one_minus_t_max) timestep = one_minus_t_max; - // which one is corrrect ? float sigma = timestep; - // float sigma = t_to_sigma(timestep * TIMESTEPS); - // float sigma = t_to_sigma(timestep * (TIMESTEPS - 1)); - // float sigma = t_to_sigma(timestep * TIMESTEPS - 1); sigmas.push_back(sigma); }