Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 54 additions & 98 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "common/common.hpp"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
#include "image_metadata.h"

const char* previews_str[] = {
Expand Down Expand Up @@ -275,7 +276,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
}

bool load_images_from_dir(const std::string dir,
std::vector<sd_image_t>& images,
SDImageVec& images,
int expected_width = 0,
int expected_height = 0,
int max_image_num = 0,
Expand Down Expand Up @@ -317,7 +318,7 @@ bool load_images_from_dir(const std::string dir,
3,
image_buffer});

if (max_image_num > 0 && images.size() >= max_image_num) {
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
break;
}
}
Expand Down Expand Up @@ -554,39 +555,17 @@ int main(int argc, const char* argv[]) {
}
}

bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames;

auto release_all_resources = [&]() {
free(init_image.data);
free(end_image.data);
free(control_image.data);
free(mask_image.data);
for (auto image : ref_images) {
free(image.data);
image.data = nullptr;
}
ref_images.clear();
for (auto image : pmid_images) {
free(image.data);
image.data = nullptr;
}
pmid_images.clear();
for (auto image : control_frames) {
free(image.data);
image.data = nullptr;
}
control_frames.clear();
};
bool vae_decode_only = true;
SDImageOwner init_image({0, 0, 3, nullptr});
SDImageOwner end_image({0, 0, 3, nullptr});
SDImageOwner control_image({0, 0, 3, nullptr});
SDImageOwner mask_image({0, 0, 1, nullptr});
SDImageVec ref_images;
SDImageVec pmid_images;
SDImageVec control_frames;

auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image,
SDImageOwner& image,
bool resize_image = true,
int expected_channel = 3) -> bool {
int expected_width = 0;
Expand All @@ -596,13 +575,12 @@ int main(int argc, const char* argv[]) {
expected_height = gen_params.height;
}

if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false;
}

gen_params.set_width_and_height_if_unset(image.width, image.height);
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
return true;
};

Expand All @@ -623,47 +601,46 @@ int main(int argc, const char* argv[]) {
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
SDImageOwner ref_image({0, 0, 3, nullptr});
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
ref_images.push_back(std::move(ref_image));
}
}

if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image,
if (!load_sd_image_from_file(mask_image.put(),
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
}
} else {
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) {
sd_image_t generated_mask = {0, 0, 1, nullptr};
generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (generated_mask.data == nullptr) {
LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1;
}
mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
generated_mask.width = gen_params.get_resolved_width();
generated_mask.height = gen_params.get_resolved_height();
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
mask_image.reset(generated_mask);
}

if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image,
if (!load_sd_image_from_file(control_image.put(),
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
}
if (cli_params.canny_preprocess) { // apply preprocessor
preprocess_canny(control_image,
preprocess_canny(control_image.get(),
0.08f,
0.08f,
0.8f,
Expand All @@ -679,7 +656,6 @@ int main(int argc, const char* argv[]) {
gen_params.get_resolved_height(),
gen_params.video_frames,
cli_params.verbose)) {
release_all_resources();
return 1;
}
}
Expand All @@ -691,7 +667,6 @@ int main(int argc, const char* argv[]) {
0,
0,
cli_params.verbose)) {
release_all_resources();
return 1;
}
}
Expand All @@ -702,39 +677,30 @@ int main(int argc, const char* argv[]) {

sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);

sd_image_t* results = nullptr;
int num_results = 0;
SDImageVec results;
int num_results = 0;

if (cli_params.mode == UPSCALE) {
num_results = 1;
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
if (results == nullptr) {
LOG_INFO("failed to allocate results array");
release_all_resources();
return 1;
}

results[0] = init_image;
init_image.data = nullptr;
results.push_back(init_image.release());
} else {
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));

if (sd_ctx == nullptr) {
LOG_INFO("new_sd_ctx_t failed");
release_all_resources();
return 1;
}

if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
}

if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
}

if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method);
}

if (cli_params.mode == IMG_GEN) {
Expand All @@ -744,19 +710,19 @@ int main(int argc, const char* argv[]) {
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
init_image.get(),
ref_images.data(),
(int)ref_images.size(),
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
mask_image.get(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
gen_params.batch_count,
control_image,
control_image.get(),
gen_params.control_strength,
{
pmid_images.data(),
Expand All @@ -768,17 +734,17 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params,
};

results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
} else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
end_image,
init_image.get(),
end_image.get(),
control_frames.data(),
(int)control_frames.size(),
gen_params.get_resolved_width(),
Expand All @@ -794,25 +760,23 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params,
};

results = generate_video(sd_ctx, &vid_gen_params, &num_results);
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
results.adopt(generated_video, num_results);
}

if (results == nullptr) {
if (!results) {
LOG_ERROR("generate failed");
free_sd_ctx(sd_ctx);
return 1;
}

free_sd_ctx(sd_ctx);
}

int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct,
ctx_params.n_threads,
gen_params.upscale_tile_size);
UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct,
ctx_params.n_threads,
gen_params.upscale_tile_size));

if (upscaler_ctx == nullptr) {
LOG_ERROR("new_upscaler_ctx failed");
Expand All @@ -821,32 +785,24 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) {
continue;
}
sd_image_t current_image = results[i];
SDImageOwner current_image(results[i]);
results[i] = {0, 0, 0, nullptr};
for (int u = 0; u < gen_params.upscale_repeats; ++u) {
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
if (upscaled_image.data == nullptr) {
SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor));
if (upscaled_image.get().data == nullptr) {
LOG_ERROR("upscale failed");
break;
}
free(current_image.data);
current_image = upscaled_image;
current_image = std::move(upscaled_image);
}
results[i] = current_image; // Set the final upscaled image as the result
results[i] = current_image.release(); // Set the final upscaled image as the result
}
}
}

if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
return 1;
}

for (int i = 0; i < num_results; i++) {
free(results[i].data);
results[i].data = nullptr;
}
free(results);

release_all_resources();

return 0;
}
11 changes: 5 additions & 6 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace fs = std::filesystem;
#endif // _WIN32

#include "log.h"
#include "resource_owners.hpp"
#include "stable-diffusion.h"

#define SAFE_STR(s) ((s) ? (s) : "")
Expand Down Expand Up @@ -1751,8 +1752,8 @@ struct SDGenerationParams {
}

std::string to_string() const {
char* sample_params_str = sd_sample_params_to_str(&sample_params);
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);
FreeUniquePtr<char> sample_params_str(sd_sample_params_to_str(&sample_params));
FreeUniquePtr<char> high_noise_sample_params_str(sd_sample_params_to_str(&high_noise_sample_params));

std::ostringstream lora_ss;
lora_ss << "{\n";
Expand Down Expand Up @@ -1801,9 +1802,9 @@ struct SDGenerationParams {
<< " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n"
<< " pm_style_strength: " << pm_style_strength << ",\n"
<< " skip_layers: " << vec_to_string(skip_layers) << ",\n"
<< " sample_params: " << sample_params_str << ",\n"
<< " sample_params: " << SAFE_STR(sample_params_str.get()) << ",\n"
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
<< " high_noise_sample_params: " << SAFE_STR(high_noise_sample_params_str.get()) << ",\n"
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
<< " cache_mode: \"" << cache_mode << "\",\n"
<< " cache_option: \"" << cache_option << "\",\n"
Expand All @@ -1829,8 +1830,6 @@ struct SDGenerationParams {
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< "}";
free(sample_params_str);
free(high_noise_sample_params_str);
return oss.str();
}
};
Expand Down
Loading
Loading