diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 8d85db554..05dc5196e 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() { &hires_upscaler}, {"", "--extra-sample-args", - "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;", + "extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware", (int)',', &extra_sample_args}, {"", @@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() { on_high_noise_sample_method_arg}, {"", "--scheduler", - "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific", + "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific", on_scheduler_arg}, {"", "--sigmas", diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 730794e6b..7e21d6624 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -70,6 +70,7 @@ enum scheduler_t { LCM_SCHEDULER, BONG_TANGENT_SCHEDULER, LTX2_SCHEDULER, + LOGIT_NORMAL_SCHEDULER, SCHEDULER_COUNT }; diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index fed5911bc..28b29ef27 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -559,6 +559,203 @@ struct LTX2Scheduler : SigmaScheduler { } }; +/* + * Logit-Normal Scheduler + * Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py + */ +struct LogitNormalScheduler : SigmaScheduler { + float mean = 0.0f; + float std = 1.75f; + float logsnr_min = -15.0f; + float logsnr_max = 18.0f; + + bool resolution_aware = true; + + float one_minus_t_min, one_minus_t_max; + + void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) { + const int known_seq_len = (512 * 512) / (16 * 16); + if (extra_sample_args) { + for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) { + if (key == "mu") { + if (!parse_strict_float(value, mean)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "std") { + if (!parse_strict_float(value, std)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } + if (key == "logsnr_min") { + if (!parse_strict_float(value, logsnr_min)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "logsnr_max") { + if (!parse_strict_float(value, logsnr_max)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "resolution_aware") { + if (!parse_strict_bool(value, resolution_aware)) { + LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } + } + } + if (image_seq_len > 0 && resolution_aware) { + mean += 0.5 * std::log(static_cast(image_seq_len) / static_cast(known_seq_len)); + } + } + + float sigmoid(float x) { + return 1.0f / (1.0f + std::exp(-x)); + } + + LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f) + : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) { + // t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + one_minus_t_min = sigmoid(0.5f * logsnr_max); + // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + one_minus_t_max = sigmoid(0.5f * logsnr_min); + + } + + LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) { + mean = 0.0f; + std = 1.75f; + logsnr_min = -15.0f; + logsnr_max = 18.0f; + + parse_extra_sample_args(image_seq_len, extra_sample_args); + // t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max)); + one_minus_t_min = sigmoid(0.5f * logsnr_max); + // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); + one_minus_t_max = sigmoid(0.5f * logsnr_min); + } + + // https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ + double ndtri(double p) { + if (p <= 0.0) { + return -std::numeric_limits::infinity(); + } else if (p >= 1.0) { + return std::numeric_limits::infinity(); + } + + static const double p_low = 0.02425; + static const double p_high = 1.0 - p_low; + + static const double c[6] = {-7.784894002430293e-03, + -3.223964580411365e-01, + -2.400758277161838e+00, + -2.549732539343734e+00, + 4.374664141464968e+00, + 2.938163982698783e+00}; + + static const double d[5] = {7.784695709041462e-03, + 3.224671290700398e-01, + 2.445134137142996e+00, + 3.754408661907416e+00, + 1.0}; + + // Coefficients for the central region + static const double a[6] = {-3.969683028665376e+01, + 2.209460984245205e+02, + -2.759285104469687e+02, + 1.383577518672690e+02, + -3.066479806614716e+01, + 2.506628277459239e+00}; + + static const double b[6] = {-5.447609879822406e+01, + 1.615858368580409e+02, + -1.556989798598866e+02, + 6.680131188771972e+01, + -1.328068155288572e+01, + 1.0}; + + double x = 0.0; + + if (p < p_low) { + // Lower region + double q = std::sqrt(-2.0 * std::log(p)); + + // Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5] + double numerator = c[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * q + c[i]; + } + + // Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1 + double denominator = d[0]; + for (int i = 1; i < 5; ++i) { + denominator = denominator * q + d[i]; + } + + x = numerator / denominator; + } else if (p > p_high) { + // Upper region + double q = std::sqrt(-2.0 * std::log(1.0 - p)); + + double numerator = c[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * q + c[i]; + } + + double denominator = d[0]; + for (int i = 1; i < 5; ++i) { + denominator = denominator * q + d[i]; + } + + x = -(numerator / denominator); + } else { + // Central region + double q = p - 0.5; + double r = q * q; + + // Numerator: (a[0]*r^5 + a[1]*r^4 + ... + a[5])*q + double numerator = a[0]; + for (int i = 1; i < 6; ++i) { + numerator = numerator * r + a[i]; + } + numerator *= q; + + // Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1 + double denominator = b[0]; + for (int i = 1; i < 6; ++i) { + denominator = denominator * r + b[i]; + } + + x = numerator / denominator; + } + return x; + } + + std::vector get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override { + std::vector sigmas; + LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max); + sigmas.reserve(n + 1); + for (uint32_t i = 0; i <= n; ++i) { + float t = static_cast(i) / static_cast(n); + + // ndtri(1-t) == -ndtri(t) + float z = -ndtri(t); + + float y = mean + std * z; + + float timestep = sigmoid(y); + + if (timestep > one_minus_t_min) + timestep = one_minus_t_min; + if (timestep < one_minus_t_max) + timestep = one_minus_t_max; + + float sigma = timestep; + + sigmas.push_back(sigma); + } + sigmas[n] = 0.0f; + return sigmas; + } +}; + struct Denoiser { virtual float sigma_min() = 0; virtual float sigma_max() = 0; @@ -623,6 +820,11 @@ struct Denoiser { LOG_INFO("get_sigmas with LTX2 scheduler"); scheduler = std::make_shared(image_seq_len, extra_sample_args); break; + case LOGIT_NORMAL_SCHEDULER: { + LOG_INFO("get_sigmas with Logit-Normal scheduler"); + scheduler = std::make_shared(image_seq_len, extra_sample_args); + break; + } default: LOG_INFO("get_sigmas with discrete scheduler (default)"); scheduler = std::make_shared(); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 93180bfd3..c33f454e9 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = { "lcm", "bong_tangent", "ltx2", + "logit_normal", }; const char* sd_scheduler_name(enum scheduler_t scheduler) { @@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return SIMPLE_SCHEDULER; } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) { return LTX2_SCHEDULER; + } else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) { + return LOGIT_NORMAL_SCHEDULER; } return DISCRETE_SCHEDULER; }