Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
(int)',',
&extra_sample_args},
{"",
Expand Down Expand Up @@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific",
on_scheduler_arg},
{"",
"--sigmas",
Expand Down
1 change: 1 addition & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ enum scheduler_t {
LCM_SCHEDULER,
BONG_TANGENT_SCHEDULER,
LTX2_SCHEDULER,
LOGIT_NORMAL_SCHEDULER,
SCHEDULER_COUNT
};

Expand Down
202 changes: 202 additions & 0 deletions src/runtime/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,203 @@ struct LTX2Scheduler : SigmaScheduler {
}
};

/*
* Logit-Normal Scheduler
* Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py
*/
struct LogitNormalScheduler : SigmaScheduler {
float mean = 0.0f;
float std = 1.75f;
float logsnr_min = -15.0f;
float logsnr_max = 18.0f;

bool resolution_aware = true;

float one_minus_t_min, one_minus_t_max;

void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
const int known_seq_len = (512 * 512) / (16 * 16);
if (extra_sample_args) {
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) {
if (key == "mu") {
if (!parse_strict_float(value, mean)) {
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "std") {
if (!parse_strict_float(value, std)) {
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
}
if (key == "logsnr_min") {
if (!parse_strict_float(value, logsnr_min)) {
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "logsnr_max") {
if (!parse_strict_float(value, logsnr_max)) {
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "resolution_aware") {
if (!parse_strict_bool(value, resolution_aware)) {
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
}
}
}
if (image_seq_len > 0 && resolution_aware) {
mean += 0.5 * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
}
}

float sigmoid(float x) {
return 1.0f / (1.0f + std::exp(-x));
}

LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f)
: mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
one_minus_t_min = sigmoid(0.5f * logsnr_max);
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
one_minus_t_max = sigmoid(0.5f * logsnr_min);

}

LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
mean = 0.0f;
std = 1.75f;
logsnr_min = -15.0f;
logsnr_max = 18.0f;

parse_extra_sample_args(image_seq_len, extra_sample_args);
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
one_minus_t_min = sigmoid(0.5f * logsnr_max);
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
one_minus_t_max = sigmoid(0.5f * logsnr_min);
}

// https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/
double ndtri(double p) {
if (p <= 0.0) {
return -std::numeric_limits<double>::infinity();
} else if (p >= 1.0) {
return std::numeric_limits<double>::infinity();
}

static const double p_low = 0.02425;
static const double p_high = 1.0 - p_low;

static const double c[6] = {-7.784894002430293e-03,
-3.223964580411365e-01,
-2.400758277161838e+00,
-2.549732539343734e+00,
4.374664141464968e+00,
2.938163982698783e+00};

static const double d[5] = {7.784695709041462e-03,
3.224671290700398e-01,
2.445134137142996e+00,
3.754408661907416e+00,
1.0};

// Coefficients for the central region
static const double a[6] = {-3.969683028665376e+01,
2.209460984245205e+02,
-2.759285104469687e+02,
1.383577518672690e+02,
-3.066479806614716e+01,
2.506628277459239e+00};

static const double b[6] = {-5.447609879822406e+01,
1.615858368580409e+02,
-1.556989798598866e+02,
6.680131188771972e+01,
-1.328068155288572e+01,
1.0};

double x = 0.0;

if (p < p_low) {
// Lower region
double q = std::sqrt(-2.0 * std::log(p));

// Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5]
double numerator = c[0];
for (int i = 1; i < 6; ++i) {
numerator = numerator * q + c[i];
}

// Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1
double denominator = d[0];
for (int i = 1; i < 5; ++i) {
denominator = denominator * q + d[i];
}

x = numerator / denominator;
} else if (p > p_high) {
// Upper region
double q = std::sqrt(-2.0 * std::log(1.0 - p));

double numerator = c[0];
for (int i = 1; i < 6; ++i) {
numerator = numerator * q + c[i];
}

double denominator = d[0];
for (int i = 1; i < 5; ++i) {
denominator = denominator * q + d[i];
}

x = -(numerator / denominator);
} else {
// Central region
double q = p - 0.5;
double r = q * q;

// Numerator: (a[0]*r^5 + a[1]*r^4 + ... + a[5])*q
double numerator = a[0];
for (int i = 1; i < 6; ++i) {
numerator = numerator * r + a[i];
}
numerator *= q;

// Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1
double denominator = b[0];
for (int i = 1; i < 6; ++i) {
denominator = denominator * r + b[i];
}

x = numerator / denominator;
}
return x;
}

std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
std::vector<float> sigmas;
LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max);
sigmas.reserve(n + 1);
for (uint32_t i = 0; i <= n; ++i) {
float t = static_cast<float>(i) / static_cast<float>(n);

// ndtri(1-t) == -ndtri(t)
float z = -ndtri(t);

float y = mean + std * z;

float timestep = sigmoid(y);

if (timestep > one_minus_t_min)
timestep = one_minus_t_min;
if (timestep < one_minus_t_max)
timestep = one_minus_t_max;

float sigma = timestep;

sigmas.push_back(sigma);
}
sigmas[n] = 0.0f;
return sigmas;
}
};

struct Denoiser {
virtual float sigma_min() = 0;
virtual float sigma_max() = 0;
Expand Down Expand Up @@ -623,6 +820,11 @@ struct Denoiser {
LOG_INFO("get_sigmas with LTX2 scheduler");
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
break;
case LOGIT_NORMAL_SCHEDULER: {
LOG_INFO("get_sigmas with Logit-Normal scheduler");
scheduler = std::make_shared<LogitNormalScheduler>(image_seq_len, extra_sample_args);
break;
}
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
Expand Down
3 changes: 3 additions & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = {
"lcm",
"bong_tangent",
"ltx2",
"logit_normal",
};

const char* sd_scheduler_name(enum scheduler_t scheduler) {
Expand Down Expand Up @@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return SIMPLE_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
return LTX2_SCHEDULER;
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
return LOGIT_NORMAL_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}
Expand Down
Loading