Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/conditioning/conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,101 @@ struct T5CLIPEmbedder : public Conditioner {
}
};

// Conditioner for MiniT2I: tokenizes the prompt with the T5 unigram tokenizer,
// runs it through a T5 runner and returns the hidden states together with a
// per-token keep mask.
struct MiniT2IConditioner : public Conditioner {
    T5UniGramTokenizer tokenizer;
    // Stays null when no T5 weights were supplied at construction time.
    std::shared_ptr<T5Runner> t5;
    // Every prompt is truncated or padded to exactly this many tokens.
    size_t prompt_length = 256;

    // Builds the T5 runner only if `tensor_storage_map` has at least one key
    // containing "text_encoders.t5xxl"; otherwise a warning is logged and `t5`
    // is left null.
    MiniT2IConditioner(ggml_backend_t backend,
                       const String2TensorStorage& tensor_storage_map          = {},
                       std::shared_ptr<RunnerWeightManager> weight_manager = nullptr) {
        for (const auto& entry : tensor_storage_map) {
            if (entry.first.find("text_encoders.t5xxl") == std::string::npos) {
                continue;
            }
            // First matching key is enough: create the runner and finish.
            t5 = std::make_shared<T5Runner>(backend, tensor_storage_map, "text_encoders.t5xxl.transformer", false, weight_manager);
            return;
        }
        LOG_WARN("IMPORTANT NOTICE: No MiniT2I T5 text encoder provided, cannot process prompts!");
    }

    // Collects the T5 parameter tensors under the "text_encoders.t5xxl.transformer"
    // prefix; does nothing without a runner.
    void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
        if (!t5) {
            return;
        }
        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
    }

    // Forwards the graph VRAM limit to the T5 runner, if present.
    void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
        if (!t5) {
            return;
        }
        t5->set_max_graph_vram_bytes(max_vram_bytes);
    }

    // Forwards the layer-streaming switch to the T5 runner, if present.
    void set_stream_layers_enabled(bool enabled) override {
        if (!t5) {
            return;
        }
        t5->set_stream_layers_enabled(enabled);
    }

    // Forwards the flash-attention switch to the T5 runner, if present.
    void set_flash_attention_enabled(bool enabled) override {
        if (!t5) {
            return;
        }
        t5->set_flash_attention_enabled(enabled);
    }

    // Forwards the weight adapter to the T5 runner, if present.
    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
        if (!t5) {
            return;
        }
        t5->set_weight_adapter(adapter);
    }

    // Notifies the T5 runner that it is done, if present.
    void runner_done() override {
        if (!t5) {
            return;
        }
        t5->runner_done();
    }

    // Encodes `conditioner_params.text` into an SDCondition.
    //   c_crossattn: hidden states returned by the T5 runner.
    //   c_vector:    keep mask of length `prompt_length` (1 = prompt token, 0 = padding).
    // Without a T5 runner both fields are zero-filled placeholders.
    SDCondition get_learned_condition(int n_threads,
                                      const ConditionerParams& conditioner_params) override {
        SDCondition result;
        if (!t5) {
            // NOTE(review): 1024 is presumably the T5 hidden size -- confirm against the model.
            result.c_vector    = sd::Tensor<float>::zeros({static_cast<int64_t>(prompt_length)});
            result.c_crossattn = sd::Tensor<float>::zeros({1024, static_cast<int64_t>(prompt_length)});
            return result;
        }

        // Tokenize, then clip anything beyond the fixed prompt length.
        std::vector<int> token_ids = tokenizer.encode(conditioner_params.text);
        if (token_ids.size() > prompt_length) {
            token_ids.resize(prompt_length);
        }

        // Real tokens are marked 1; each padding slot appended below is marked 0.
        const size_t real_token_count = token_ids.size();
        std::vector<float> keep_mask(real_token_count, 1.0f);
        for (size_t slot = real_token_count; slot < prompt_length; ++slot) {
            token_ids.push_back(tokenizer.PAD_TOKEN_ID);
            keep_mask.push_back(0.0f);
        }

        sd::Tensor<int32_t> input_ids({static_cast<int64_t>(token_ids.size())}, token_ids);

        // Turn the 0/1 keep mask into an additive one: 0 for kept tokens,
        // negative infinity for padding.
        std::vector<float> additive_mask;
        additive_mask.reserve(keep_mask.size());
        for (const float keep : keep_mask) {
            additive_mask.push_back(keep > 0.0f ? 0.0f : -HUGE_VALF);
        }

        // NOTE(review): the meaning of the three trailing boolean flags is not
        // visible from this block -- check T5Runner::compute.
        sd::Tensor<float> hidden_states = t5->compute(n_threads,
                                                      input_ids,
                                                      sd::Tensor<float>::from_vector(additive_mask),
                                                      false,
                                                      true,
                                                      true);
        GGML_ASSERT(!hidden_states.empty());

        result.c_crossattn = std::move(hidden_states);
        result.c_vector    = sd::Tensor<float>::from_vector(keep_mask);
        return result;
    }
};

struct AnimaConditioner : public Conditioner {
std::shared_ptr<BPETokenizer> qwen_tokenizer;
T5UniGramTokenizer t5_tokenizer;
Expand Down
81 changes: 79 additions & 2 deletions src/core/ggml_extend_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,67 @@ static std::string resolve_first_device_by_type(enum ggml_backend_dev_type type)
if (dev == nullptr) {
return "";
}
return ggml_backend_dev_name(dev);
const char* dev_name = ggml_backend_dev_name(dev);
if (dev_name != nullptr && dev_name[0] != '\0') {
return dev_name;
}
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
const char* reg_name = reg != nullptr ? ggml_backend_reg_name(reg) : nullptr;
return reg_name != nullptr ? reg_name : "";
}

// Returns the first enumerated device whose backend registry name equals
// `name` after trimming and lower-casing; "metal" is treated as an alias for
// "mtl". Returns nullptr for an empty name or when nothing matches.
static ggml_backend_dev_t resolve_first_device_by_registry_name(const std::string& name) {
    const std::string normalized = lower_copy(trim_copy(name));
    if (normalized.empty()) {
        return nullptr;
    }
    // Compare against "mtl" when the caller asked for "metal".
    const std::string wanted = (normalized == "metal") ? std::string("mtl") : normalized;

    for (size_t index = 0, total = ggml_backend_dev_count(); index < total; ++index) {
        ggml_backend_dev_t candidate = ggml_backend_dev_get(index);
        ggml_backend_reg_t registry  = ggml_backend_dev_backend_reg(candidate);
        const char* registry_name    = registry != nullptr ? ggml_backend_reg_name(registry) : nullptr;
        if (registry_name == nullptr) {
            // No registry, or a registry without a name: cannot match.
            continue;
        }
        if (lower_copy(registry_name) == wanted) {
            return candidate;
        }
    }
    return nullptr;
}

// Returns the first enumerated device whose own name equals `name` after
// trimming and lower-casing, or nullptr if the name is empty or unmatched.
static ggml_backend_dev_t resolve_device_by_name(const std::string& name) {
    const std::string wanted = lower_copy(trim_copy(name));
    if (wanted.empty()) {
        return nullptr;
    }

    for (size_t index = 0, total = ggml_backend_dev_count(); index < total; ++index) {
        ggml_backend_dev_t candidate = ggml_backend_dev_get(index);
        const char* candidate_name   = ggml_backend_dev_name(candidate);
        if (candidate_name == nullptr) {
            continue;
        }
        if (lower_copy(candidate_name) == wanted) {
            return candidate;
        }
    }
    return nullptr;
}

// Returns a printable name for `dev`: the device's own name when it is
// non-empty, otherwise its backend registry name, otherwise an empty string
// (also returned for a null device).
static std::string backend_device_name(ggml_backend_dev_t dev) {
    if (dev == nullptr) {
        return std::string();
    }

    const char* own_name = ggml_backend_dev_name(dev);
    const bool has_own_name = own_name != nullptr && *own_name != '\0';
    if (has_own_name) {
        return std::string(own_name);
    }

    // Fall back to the name of the registry the device belongs to.
    ggml_backend_reg_t registry = ggml_backend_dev_backend_reg(dev);
    if (registry == nullptr) {
        return std::string();
    }
    const char* registry_name = ggml_backend_reg_name(registry);
    if (registry_name == nullptr) {
        return std::string();
    }
    return std::string(registry_name);
}

static ggml_backend_buffer_t ggml_backend_tensor_buffer(const struct ggml_tensor* tensor) {
Expand Down Expand Up @@ -296,6 +356,10 @@ std::string sd_backend_resolve_name(const std::string& name) {
return resolve_first_device_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);
}

if (ggml_backend_dev_t dev = resolve_first_device_by_registry_name(requested)) {
return backend_device_name(dev);
}

const size_t device_count = ggml_backend_dev_count();
for (size_t i = 0; i < device_count; ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
Expand Down Expand Up @@ -328,7 +392,20 @@ static ggml_backend_t init_named_backend(const std::string& name) {
return ggml_backend_init_best();
}

if (ggml_backend_dev_t dev = resolve_device_by_name(name)) {
return ggml_backend_dev_init(dev, nullptr);
}
if (ggml_backend_dev_t dev = resolve_first_device_by_registry_name(name)) {
return ggml_backend_dev_init(dev, nullptr);
}

std::string resolved = sd_backend_resolve_name(name);
if (ggml_backend_dev_t dev = resolve_device_by_name(resolved)) {
return ggml_backend_dev_init(dev, nullptr);
}
if (ggml_backend_dev_t dev = resolve_first_device_by_registry_name(resolved)) {
return ggml_backend_dev_init(dev, nullptr);
}
if (resolved.empty()) {
return nullptr;
}
Expand Down Expand Up @@ -599,7 +676,7 @@ bool SDBackendManager::validate(std::string* error) const {
}
return false;
}
if (!sd_backend_resolve_name(name).empty()) {
if (!sd_backend_resolve_name(name).empty() || resolve_first_device_by_registry_name(name) != nullptr) {
return true;
}
if (error != nullptr) {
Expand Down
9 changes: 9 additions & 0 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ enum SDVersion {
VERSION_OVIS_IMAGE,
VERSION_ERNIE_IMAGE,
VERSION_LENS,
VERSION_MINIT2I,
VERSION_LONGCAT,
VERSION_PID,
VERSION_IDEOGRAM4,
Expand Down Expand Up @@ -164,6 +165,13 @@ static inline bool sd_version_is_lens(SDVersion version) {
return false;
}

// True exactly when `version` is VERSION_MINIT2I.
static inline bool sd_version_is_minit2i(SDVersion version) {
    return version == VERSION_MINIT2I;
}

static inline bool sd_version_is_pid(SDVersion version) {
if (version == VERSION_PID) {
return true;
Expand Down Expand Up @@ -208,6 +216,7 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_z_image(version) ||
sd_version_is_ernie_image(version) ||
sd_version_is_lens(version) ||
sd_version_is_minit2i(version) ||
sd_version_is_longcat(version) ||
sd_version_is_pid(version) ||
sd_version_is_ideogram4(version)) {
Expand Down
Loading