Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/runtime/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,19 @@ struct Denoiser {
virtual sd::Tensor<float> inverse_noise_scaling(float sigma,
const sd::Tensor<float>& latent) = 0;

virtual float apply_shift_warp(float sigma) {
return sigma;
}

virtual std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version, const char* extra_sample_args = nullptr) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
bool shifted_scheduler = false;
switch (scheduler_type) {
case DISCRETE_SCHEDULER:
LOG_INFO("get_sigmas with discrete scheduler");
scheduler = std::make_shared<DiscreteScheduler>();
shifted_scheduler = true;
break;
case KARRAS_SCHEDULER:
LOG_INFO("get_sigmas with Karras scheduler");
Expand All @@ -598,14 +604,17 @@ struct Denoiser {
case SGM_UNIFORM_SCHEDULER:
LOG_INFO("get_sigmas with SGM Uniform scheduler");
scheduler = std::make_shared<SGMUniformScheduler>();
shifted_scheduler = true;
break;
case SIMPLE_SCHEDULER:
LOG_INFO("get_sigmas with Simple scheduler");
scheduler = std::make_shared<SimpleScheduler>();
shifted_scheduler = true;
break;
case SMOOTHSTEP_SCHEDULER:
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
shifted_scheduler = true;
break;
case BONG_TANGENT_SCHEDULER:
LOG_INFO("get_sigmas with bong_tangent scheduler");
Expand All @@ -618,6 +627,7 @@ struct Denoiser {
case LCM_SCHEDULER:
LOG_INFO("get_sigmas with LCM scheduler");
scheduler = std::make_shared<LCMScheduler>();
shifted_scheduler = true;
break;
case LTX2_SCHEDULER:
LOG_INFO("get_sigmas with LTX2 scheduler");
Expand All @@ -628,7 +638,14 @@ struct Denoiser {
scheduler = std::make_shared<DiscreteScheduler>();
break;
}
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
std::vector<float> sigmas = scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);

if (!shifted_scheduler) {
for (size_t i = 0; i < sigmas.size() - 1; ++i) {
sigmas[i] = apply_shift_warp(sigmas[i]);
}
}
return sigmas;
}
};

Expand Down Expand Up @@ -763,6 +780,10 @@ struct DiscreteFlowDenoiser : public Denoiser {
float sigma_to_t(float sigma) override {
return sigma * 1000.f;
}

float apply_shift_warp(float sigma) override {
return time_snr_shift(shift, sigma);
}

float t_to_sigma(float t) override {
t = t + 1;
Expand Down Expand Up @@ -798,9 +819,13 @@ struct FluxFlowDenoiser : public DiscreteFlowDenoiser {
return sigma;
}

float apply_shift_warp(float sigma) override {
return flux_time_shift(shift, 1.0f, sigma);
}

float t_to_sigma(float t) override {
t = t + 1;
return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
return apply_shift_warp(t / TIMESTEPS);
}
};

Expand Down
Loading