Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 106 additions & 12 deletions src/model/diffusion/ideogram4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,42 @@ namespace Ideogram4 {
int rope_theta,
const std::vector<int>& mrope_section,
bool circular_x = false,
bool circular_y = false) {
bool circular_y = false,
const std::vector<ggml_tensor*>& ref_latents = {},
bool increase_ref_index = false) {
GGML_ASSERT(bs == 1);
std::vector<std::vector<float>> ids(static_cast<size_t>(bs) * (context_len + grid_h * grid_w),

int total_ref_tokens = 0;
int max_h = grid_h;
int max_w = grid_w;
int index = 0;
int h_offset = 0;
int w_offset = 0;
int current_h = 0;
int current_w = 0;
for (ggml_tensor* ref : ref_latents) {
int ref_h = static_cast<int>(ref->ne[1]);
int ref_w = static_cast<int>(ref->ne[0]);
total_ref_tokens += ref_h * ref_w;
if (increase_ref_index) {
index += 1;
} else {
index = 1;
h_offset = 0;
w_offset = 0;
if (ref_h + current_h > ref_w + current_w) {
w_offset = current_w;
} else {
h_offset = current_h;
}
current_h = std::max(current_h, ref_h + h_offset);
current_w = std::max(current_w, ref_w + w_offset);
}
max_h = std::max(max_h, ref_h + h_offset);
max_w = std::max(max_w, ref_w + w_offset);
}

std::vector<std::vector<float>> ids(static_cast<size_t>(bs) * (context_len + grid_h * grid_w + total_ref_tokens),
std::vector<float>(3, 0.f));

for (int i = 0; i < context_len; ++i) {
Expand All @@ -171,19 +204,48 @@ namespace Ideogram4 {
}
}

index = 0;
current_h = 0;
current_w = 0;
for (ggml_tensor* ref : ref_latents) {
int ref_h = static_cast<int>(ref->ne[1]);
int ref_w = static_cast<int>(ref->ne[0]);
int gh_offset = 0;
int gw_offset = 0;
if (increase_ref_index) {
index += 1;
} else {
index = 1;
if (ref_h + current_h > ref_w + current_w) {
gw_offset = current_w;
} else {
gh_offset = current_h;
}
current_h = std::max(current_h, ref_h + gh_offset);
current_w = std::max(current_w, ref_w + gw_offset);
}
for (int y = 0; y < ref_h; ++y) {
for (int x = 0; x < ref_w; ++x) {
ids[cursor++] = {static_cast<float>(IMAGE_POSITION_OFFSET + index),
static_cast<float>(IMAGE_POSITION_OFFSET + gh_offset + y),
static_cast<float>(IMAGE_POSITION_OFFSET + gw_offset + x)};
}
}
}

std::vector<std::vector<int>> axis_wrap_dims(3);
if (circular_y || circular_x) {
size_t total_len = static_cast<size_t>(bs) * (context_len + grid_h * grid_w);
size_t total_len = static_cast<size_t>(bs) * (context_len + grid_h * grid_w + total_ref_tokens);
axis_wrap_dims[1].assign(total_len, 0);
axis_wrap_dims[2].assign(total_len, 0);
if (circular_y) {
for (size_t idx = static_cast<size_t>(context_len); idx < total_len; ++idx) {
axis_wrap_dims[1][idx] = grid_h;
axis_wrap_dims[1][idx] = max_h;
}
}
if (circular_x) {
for (size_t idx = static_cast<size_t>(context_len); idx < total_len; ++idx) {
axis_wrap_dims[2][idx] = grid_w;
axis_wrap_dims[2][idx] = max_w;
}
}
}
Expand Down Expand Up @@ -377,7 +439,8 @@ namespace Ideogram4 {
ggml_tensor* timestep,
ggml_tensor* context,
ggml_tensor* pe,
ggml_tensor* image_indicator_ids) {
ggml_tensor* image_indicator_ids,
std::vector<ggml_tensor*> ref_latents = {}) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t N = x->ne[3];
Expand All @@ -392,7 +455,16 @@ namespace Ideogram4 {
auto final_layer = std::dynamic_pointer_cast<Ideogram4FinalLayer>(blocks["final_layer"]);

auto img = patchify(ctx->ggml_ctx, x, config);
img = input_proj->forward(ctx, img);
int64_t n_img_tokens = img->ne[1];
img = input_proj->forward(ctx, img);

if (!ref_latents.empty()) {
for (ggml_tensor* ref : ref_latents) {
ref = patchify(ctx->ggml_ctx, ref, config);
ref = input_proj->forward(ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}

ggml_tensor* h = img;
int64_t context_len = 0;
Expand All @@ -407,7 +479,7 @@ namespace Ideogram4 {
h = ggml_concat(ctx->ggml_ctx, txt, img, 1);
}

auto indicator_embedding = embed_image_indicator->forward(ctx, image_indicator_ids);
auto indicator_embedding = embed_image_indicator->forward(ctx, image_indicator_ids);
h = ggml_add(ctx->ggml_ctx, h, indicator_embedding);

auto t_cond = t_embedding->forward(ctx, timestep);
Expand All @@ -423,6 +495,9 @@ namespace Ideogram4 {
if (context_len > 0) {
h = ggml_ext_slice(ctx->ggml_ctx, h, 1, context_len, h->ne[1]);
}
if (h->ne[1] > n_img_tokens) {
h = ggml_ext_slice(ctx->ggml_ctx, h, 1, 0, n_img_tokens);
}

h = unpatchify(ctx->ggml_ctx, h, H, W, config);
h = ggml_ext_scale(ctx->ggml_ctx, h, -1.f);
Expand Down Expand Up @@ -485,6 +560,8 @@ namespace Ideogram4 {
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
const sd::Tensor<float>& timesteps_tensor,
const sd::Tensor<float>& context_tensor,
const std::vector<sd::Tensor<float>>& ref_latents_tensor = {},
bool increase_ref_index = false,
bool use_uncond_model = false) {
ggml_cgraph* gf = new_graph_custom(IDEOGRAM4_GRAPH_SIZE);
ggml_tensor* x = make_input(x_tensor);
Expand All @@ -499,9 +576,19 @@ namespace Ideogram4 {
context_len = context->ne[1];
}

std::vector<ggml_tensor*> ref_latents;
ref_latents.reserve(ref_latents_tensor.size());
for (const auto& ref_latent_tensor : ref_latents_tensor) {
ref_latents.push_back(make_input(ref_latent_tensor));
}

int64_t grid_w = x->ne[0];
int64_t grid_h = x->ne[1];
int64_t pos_len = context_len + grid_h * grid_w;
int64_t total_ref_tokens = 0;
for (ggml_tensor* ref : ref_latents) {
total_ref_tokens += ref->ne[0] * ref->ne[1];
}
int64_t pos_len = context_len + grid_h * grid_w + total_ref_tokens;
int64_t head_dim = config.emb_dim / config.num_heads;

auto runner_ctx = get_context();
Expand All @@ -513,7 +600,9 @@ namespace Ideogram4 {
static_cast<int>(config.rope_theta),
config.mrope_section,
runner_ctx.circular_x_enabled,
runner_ctx.circular_y_enabled);
runner_ctx.circular_y_enabled,
ref_latents,
increase_ref_index);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
set_backend_tensor_data(pe, pe_vec.data());

Expand All @@ -524,7 +613,7 @@ namespace Ideogram4 {
auto indicator = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_I32, pos_len, x->ne[3]);
set_backend_tensor_data(indicator, image_indicator_vec.data());

ggml_tensor* out = active_model.forward(&runner_ctx, x, timesteps, context, pe, indicator);
ggml_tensor* out = active_model.forward(&runner_ctx, x, timesteps, context, pe, indicator, ref_latents);
ggml_build_forward_expand(gf, out);
return gf;
}
Expand All @@ -533,9 +622,11 @@ namespace Ideogram4 {
const sd::Tensor<float>& x,
const sd::Tensor<float>& timesteps,
const sd::Tensor<float>& context,
const std::vector<sd::Tensor<float>>& ref_latents = {},
bool increase_ref_index = false,
bool use_uncond_model = false) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, use_uncond_model);
return build_graph(x, timesteps, context, ref_latents, increase_ref_index, use_uncond_model);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
}
Expand All @@ -544,11 +635,14 @@ namespace Ideogram4 {
const DiffusionParams& diffusion_params) override {
GGML_ASSERT(diffusion_params.x != nullptr);
GGML_ASSERT(diffusion_params.timesteps != nullptr);
static const std::vector<sd::Tensor<float>> empty_ref_latents;
bool use_uncond_model = should_use_uncond_model(diffusion_params);
return compute(n_threads,
*diffusion_params.x,
*diffusion_params.timesteps,
tensor_or_empty(diffusion_params.context),
diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents,
diffusion_params.increase_ref_index,
use_uncond_model);
}
};
Expand Down
3 changes: 2 additions & 1 deletion src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ static bool sd_version_supports_ref_latent_img_cfg(SDVersion version) {
sd_version_is_qwen_image(version) ||
sd_version_is_longcat(version) ||
sd_version_is_z_image(version) ||
sd_version_is_boogu_image(version);
sd_version_is_boogu_image(version) ||
sd_version_is_ideogram4(version);
}

static bool sd_version_supports_img_cfg(SDVersion version, bool has_ref_images) {
Expand Down
Loading