Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
(int)',',
&extra_sample_args},
{"",
Expand Down
102 changes: 92 additions & 10 deletions src/runtime/guidance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cstdlib>
#include <string>
#include <utility>
#include <optional>

#include "core/util.h"

Expand Down Expand Up @@ -63,15 +64,93 @@ namespace sd::guidance {
return uncond;
}

std::vector<float> parse_guidance_schedule_from_spec(std::string spec) {
std::vector<float> schedule;

while (!spec.empty()) {
auto sep = spec.find('+');
auto segment = spec.substr(0, sep);

auto x = segment.find('x');
if (x == std::string::npos) {
LOG_ERROR("Invalid guidance schedule segment: '%s' (expected <guidance>x<count>)", segment.c_str());
return {};
}

float guidance;
int count;

auto guidance_str = segment.substr(0, x);
auto count_str = segment.substr(x + 1);

try {
size_t idx = 0;
guidance = std::stof(guidance_str, &idx);
if (idx != guidance_str.size()) {
LOG_ERROR("Invalid guidance value in guidance schedule: '%s'", guidance_str.c_str());
return {};
}
} catch (const std::exception&) {
LOG_ERROR("Invalid guidance value in guidance schedule: '%s'", guidance_str.c_str());
return {};
}

try {
size_t idx = 0;
count = std::stoi(count_str, &idx);
if (idx != count_str.size()) {
LOG_ERROR("Invalid count in guidance schedule: '%s'", count_str.c_str());
return {};
}
} catch (const std::exception&) {
LOG_ERROR("Invalid count in guidance schedule: '%s'", count_str.c_str());
return {};
}

if (count <= 0) {
LOG_ERROR("Guidance schedule count must be positive");
return {};
}

schedule.insert(schedule.end(), count, guidance);

if (sep == std::string::npos) {
break;
}

spec = spec.substr(sep + 1);
}

return schedule;
}

std::vector<float> parse_guidance_schedule(const char* extra_sample_args) {
std::vector<float> guidance_schedule;
std::string guidance_schedule_str = "";
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "extra sample arg")) {
float parsed = 0.0f;
if (key == "guidance_schedule") {
guidance_schedule_str = value;
}
}

if (!guidance_schedule_str.empty()) {
guidance_schedule = parse_guidance_schedule_from_spec(guidance_schedule_str);
}
return guidance_schedule;
}

ClassifierFreeGuidance::ClassifierFreeGuidance(float guidance_scale,
float image_guidance_scale)
: guidance_scale_(guidance_scale),
image_guidance_scale_(image_guidance_scale) {
}

GuiderOutput ClassifierFreeGuidance::forward(const GuidanceInput& input,
GuiderOutput previous) const {
GuiderOutput previous,
std::optional<float> scale_override) const {
(void)previous;
float guidance_scale = scale_override.value_or(guidance_scale_);

GuiderOutput output;
if (!has_tensor(input.pred_cond)) {
Expand All @@ -86,14 +165,14 @@ namespace sd::guidance {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_uncond);
guidance_scale * (pred_cond - pred_uncond);

} else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
}
} else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
}

return output;
Expand Down Expand Up @@ -128,8 +207,10 @@ namespace sd::guidance {
}

GuiderOutput AdaptiveProjectedGuidance::forward(const GuidanceInput& input,
GuiderOutput previous) const {
GuiderOutput previous,
std::optional<float> scale_override) const {
(void)previous;
float guidance_scale = scale_override.value_or(guidance_scale_);

GuiderOutput output;
if (!has_tensor(input.pred_cond)) {
Expand All @@ -144,13 +225,13 @@ namespace sd::guidance {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_uncond);
guidance_scale * (pred_cond - pred_uncond);
} else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
}
} else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
}
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_uncond)) {
return output;
Expand All @@ -162,7 +243,7 @@ namespace sd::guidance {
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
pred_uncond,
pred_img_uncond,
guidance_scale_,
guidance_scale,
image_guidance_scale_);
if (params_.momentum != 0.0f) {
if (momentum_buffer_.shape() != deltas.shape()) {
Expand Down Expand Up @@ -239,7 +320,8 @@ namespace sd::guidance {
}

GuiderOutput SkipLayerGuidance::forward(const GuidanceInput& input,
GuiderOutput output) const {
GuiderOutput output,
std::optional<float> /*scale_override*/) const {
if (scale_ == 0.0f || !is_enabled_for_step(input) || !input.predict_skip_layer) {
return output;
}
Expand Down
14 changes: 10 additions & 4 deletions src/runtime/guidance.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstddef>
#include <functional>
#include <vector>
#include <optional>

#include "core/tensor.hpp"

Expand All @@ -27,6 +28,7 @@ namespace sd::guidance {
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args);
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params);
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args);
std::vector<float> parse_guidance_schedule(const char* extra_sample_args);

struct GuidanceInput {
int step = 0;
Expand All @@ -42,7 +44,8 @@ namespace sd::guidance {
public:
virtual ~BaseGuidance() = default;
virtual GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const = 0;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const = 0;
};

class ClassifierFreeGuidance : public BaseGuidance {
Expand All @@ -54,7 +57,8 @@ namespace sd::guidance {
float image_guidance_scale);

GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};

class AdaptiveProjectedGuidance : public BaseGuidance {
Expand All @@ -69,7 +73,8 @@ namespace sd::guidance {
AdaptiveProjectedGuidanceParams params);

GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};

class SkipLayerGuidance : public BaseGuidance {
Expand All @@ -88,7 +93,8 @@ namespace sd::guidance {
const std::vector<int>& layers() const;

GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};

} // namespace sd::guidance
Expand Down
30 changes: 29 additions & 1 deletion src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,32 @@ class StableDiffusionGGML {
float img_cfg_scale = guidance.img_cfg;
float slg_scale = guidance.slg.scale;
bool slg_uncond = sd::guidance::parse_skip_layer_guidance_uncond_arg(extra_sample_args);

std::vector<float> guidance_schedule = sd::guidance::parse_guidance_schedule(extra_sample_args);
if(!guidance_schedule.empty() && guidance_schedule.size() != sigmas.size() - 1) {
if(guidance_schedule.size() > sigmas.size()) {
LOG_WARN("guidance_schedule length (%zu) is greater than number of steps (%zu)", guidance_schedule.size(), sigmas.size() - 1);
LOG_WARN("truncating guidance_schedule to match step count");
guidance_schedule.resize(sigmas.size() - 1);
} else {
LOG_INFO("padding guidance_schedule with cfg_scale");
while(guidance_schedule.size() < sigmas.size() - 1) {
guidance_schedule.push_back(cfg_scale);
}
}
}

if(!guidance_schedule.empty()) {
std::string schedule_str = "[";
for(size_t i = 0; i < guidance_schedule.size(); ++i) {
schedule_str += std::to_string(guidance_schedule[i]);
if(i < guidance_schedule.size() - 1) {
schedule_str += ", ";
}
}
schedule_str += "]";
LOG_DEBUG("using guidance schedule: %s", schedule_str.c_str());
}

sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version,
cache_params,
Expand Down Expand Up @@ -2182,7 +2208,9 @@ class StableDiffusionGGML {
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
guidance_input.pred_img_uncond = img_uncond_out.empty() ? nullptr : &img_uncond_out;

sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {});
sd::guidance::GuiderOutput guided = guidance_schedule.empty()?
primary_guidance.forward(guidance_input, {}):
primary_guidance.forward(guidance_input, {}, guidance_schedule[guidance_schedule.size() - 1 - step]);
if (guided.pred.empty()) {
return {};
}
Expand Down
Loading