Skip to content

Commit 555cb1f

Browse files
author
Chris Warren-Smith
committed
LLAMA: special handling for chat templates for gemma
1 parent f423fbf commit 555cb1f

3 files changed

Lines changed: 36 additions & 11 deletions

File tree

llama/llama-sb.cpp

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Llama::Llama() :
4646
_max_tokens(0),
4747
_log_level(GGML_LOG_LEVEL_CONT),
4848
_n_past(0),
49+
_is_gemma4(false),
4950
_seed(LLAMA_DEFAULT_SEED) {
5051
llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
5152
Llama *llama = (Llama *)user_data;
@@ -66,6 +67,7 @@ Llama::Llama(Llama &&other) noexcept
6667
, _grammar_src(std::move(other._grammar_src))
6768
, _grammar_root(std::move(other._grammar_root))
6869
, _last_error(std::move(other._last_error))
70+
, _template(std::move(other._template))
6971
, _penalty_last_n(other._penalty_last_n)
7072
, _penalty_repeat(other._penalty_repeat)
7173
, _penalty_freq(other._penalty_freq)
@@ -77,6 +79,7 @@ Llama::Llama(Llama &&other) noexcept
7779
, _max_tokens(other._max_tokens)
7880
, _log_level(other._log_level)
7981
, _n_past(other._n_past)
82+
, _is_gemma4(other._is_gemma4)
8083
, _seed(other._seed) {
8184
}
8285

@@ -95,7 +98,7 @@ Llama::~Llama() {
9598

9699
void Llama::reset() {
97100
_stop_sequences.clear();
98-
_last_error = "";
101+
_last_error.clear();
99102
_penalty_last_n = 64;
100103
_penalty_repeat = 1.1f;
101104
_penalty_freq = 0.0f;
@@ -106,8 +109,10 @@ void Llama::reset() {
106109
_min_p = 0.0f;
107110
_max_tokens = 150;
108111
_n_past = 0;
112+
_is_gemma4 = false;
109113
_grammar_src.clear();
110114
_grammar_root.clear();
115+
_template.clear();
111116
_seed = LLAMA_DEFAULT_SEED;
112117
if (_ctx) {
113118
llama_memory_clear(llama_get_memory(_ctx), true);
@@ -142,9 +147,9 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch, int n_gpu_layer
142147
_vocab = llama_model_get_vocab(_model);
143148
}
144149
_template = llama_model_chat_template(_model, nullptr);
150+
_is_gemma4 = (_template.find("<|turn>model") != string::npos);
145151
}
146152

147-
148153
return _last_error.empty();
149154
}
150155

@@ -268,16 +273,34 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
268273
}
269274

270275
bool Llama::add_message(LlamaIter &iter, const string &role, const string &content) {
271-
llama_chat_message msg = {role.c_str(), content.c_str()};
272-
276+
llama_chat_message message = {role.c_str(), content.c_str()};
273277
int buf_size = 2 * (int)(role.size() + content.size() + 64);
274278
vector<char> buf(buf_size);
275-
bool add_ass = (role == "user");
279+
bool add_ass = (role == "user" || role == "tool");
280+
int32_t n = 0;
281+
282+
if (_template.empty()) {
283+
_last_error = "No chat template available";
284+
return false;
285+
}
276286

277-
int32_t n = llama_chat_apply_template(_template, &msg, 1, add_ass, buf.data(), buf.size());
278-
if (n > (int32_t)buf.size()) {
279-
buf.resize(n);
280-
llama_chat_apply_template(_template, &msg, 1, add_ass, buf.data(), buf.size());
287+
if (_is_gemma4) {
288+
string str = "<|turn>" + role + "\n" + content + "<turn|>\n";
289+
if (add_ass) {
290+
str += "<|turn>model\n";
291+
}
292+
n = str.size();
293+
buf.assign(str.begin(), str.end());
294+
buf.push_back('\0');
295+
} else {
296+
n = llama_chat_apply_template(_template.c_str(), &message, 1, add_ass, buf.data(), buf_size);
297+
if (n < 0) {
298+
_last_error = "No chat template no supported";
299+
return false;
300+
} else if (n > (int32_t)buf.size()) {
301+
buf.resize(n);
302+
llama_chat_apply_template(_template.c_str(), &message, 1, add_ass, buf.data(), buf.size());
303+
}
281304
}
282305
string prompt(buf.data(), n);
283306

llama/llama-sb.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct Llama {
9191
string _grammar_src;
9292
string _grammar_root;
9393
string _last_error;
94-
const char *_template;
94+
string _template;
9595
int32_t _penalty_last_n;
9696
float _penalty_repeat;
9797
float _penalty_freq;
@@ -103,5 +103,6 @@ struct Llama {
103103
int _max_tokens;
104104
int _log_level;
105105
int _n_past;
106+
bool _is_gemma4;
106107
unsigned int _seed;
107108
};

llama/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ static int cmd_llama_add_message(var_s *self, int argc, slib_par_t *arg, var_s *
413413
int iter_id = ++g_nextId;
414414
LlamaIter &iter = g_llama_iter[iter_id];
415415
Llama &llama = g_llama.at(id);
416-
auto role = get_param_str(argc, arg, 0, "");
416+
auto role = get_param_str(argc, arg, 0, "user");
417417
auto content = get_param_str(argc, arg, 1, "");
418418
if (llama.add_message(iter, role, content)) {
419419
map_init_id(retval, iter_id, CLASS_ID_LLAMA_ITER);
@@ -423,6 +423,7 @@ static int cmd_llama_add_message(var_s *self, int argc, slib_par_t *arg, var_s *
423423
v_create_callback(retval, "tokens_sec", cmd_llama_tokens_sec);
424424
result = 1;
425425
} else {
426+
g_llama_iter.erase(iter_id);
426427
error(retval, llama.last_error());
427428
}
428429
}

0 commit comments

Comments
 (0)