#include <android/log.h>
#include <jni.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
#include <unistd.h>

#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "sampling.h"

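/**
 * JNI bridge between InferenceEngineImpl (Kotlin) and llama.cpp.
 *
 * Expected call order from the Kotlin side:
 *   init -> load -> prepare -> processSystemPrompt
 *       -> [processUserPrompt -> generateNextToken (until null)]*
 *       -> unload -> shutdown
 */
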
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
    std::ostringstream str;
    for (size_t i = 0; i < values.size(); i++) {
        str << values[i];
        if (i < values.size() - 1) { str << delim; }
    }
    return str.str();
}

/**
 * Llama resources: model, context, batch, sampler, and chat templates
 */
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;

constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
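// The constants above: N_THREADS_* clamp inference threads to [2, 4], leaving
// N_THREADS_HEADROOM cores for the UI and system; OVERFLOW_HEADROOM reserves a
// few slots at the end of the context window so a shift can be triggered before
// it actually overflows; BATCH_SIZE caps both the logical (n_batch) and
// physical (n_ubatch) batch sizes.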

static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;

extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
    // Set llama log handler to Android
    llama_log_set(aichat_android_log_callback, nullptr);

    // Load all CPU backend variants
    const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, nullptr);
    LOGi("Loading backends from %s", path_to_backend);
    ggml_backend_load_all_from_path(path_to_backend);
    env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);

    // Initialize backends
    llama_backend_init();
    LOGi("Backend initialized; log handler set.");
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject /*unused*/, jstring jmodel_path) {
    llama_model_params model_params = llama_model_default_params();

    const auto *model_path = env->GetStringUTFChars(jmodel_path, nullptr);
    LOGd("%s: Loading model from: \n%s\n", __func__, model_path);

    auto *model = llama_model_load_from_file(model_path, model_params);
    env->ReleaseStringUTFChars(jmodel_path, model_path);
    if (!model) {
        return 1;
    }
    g_model = model;
    return 0;
}

static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
    if (!model) {
        LOGe("%s: model cannot be null", __func__);
        return nullptr;
    }

    // Multi-threading setup
    const int n_threads = std::clamp(
            (int) sysconf(_SC_NPROCESSORS_ONLN) - N_THREADS_HEADROOM,
            N_THREADS_MIN, N_THREADS_MAX);
    LOGi("%s: Using %d threads", __func__, n_threads);

    // Context parameters setup
    llama_context_params ctx_params = llama_context_default_params();
    const int trained_context_size = llama_model_n_ctx_train(model);
    if (n_ctx > trained_context_size) {
        LOGw("%s: Model was trained with a context size of only %d! Enforcing a context size of %d...",
             __func__, trained_context_size, n_ctx);
    }
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = BATCH_SIZE;
    ctx_params.n_ubatch = BATCH_SIZE;
    ctx_params.n_threads = n_threads;
    ctx_params.n_threads_batch = n_threads;
    auto *context = llama_init_from_model(model, ctx_params);
    if (context == nullptr) {
        LOGe("%s: llama_init_from_model() returned null", __func__);
    }
    return context;
}

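// Builds a sampler chain with the library's default sampling parameters,
// overriding only the temperature.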
static common_sampler *new_sampler(float temp) {
    common_params_sampling sparams;
    sparams.temp = temp;
    return common_sampler_init(g_model, sparams);
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
    auto *context = init_context(g_model);
    if (!context) { return 1; }
    g_context = context;
    g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
    g_chat_templates = common_chat_templates_init(g_model, "");
    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
    return 0;
}

static std::string get_backend() {
    std::vector<std::string> backends;
    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
        auto *reg = ggml_backend_reg_get(i);
        std::string name = ggml_backend_reg_name(reg);
        if (name != "CPU") {
            backends.push_back(name);
        }
    }
    return backends.empty() ? "CPU" : join(backends, ",");
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
    return env->NewStringUTF(llama_print_system_info());
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
                                                            jint pl, jint nr) {
    auto *context = init_context(g_model, pp);
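    // Note: the bench context is sized to pp tokens, which assumes pp >= tg
    // so that the text-generation phase also fits.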
    if (!context) {
        const auto *const err_msg = "Failed to init_context! Bench aborted.";
        LOGe("%s", err_msg);
        return env->NewStringUTF(err_msg);
    }

    auto pp_avg = 0.0;
    auto tg_avg = 0.0;
    auto pp_std = 0.0;
    auto tg_std = 0.0;

    const uint32_t n_ctx = llama_n_ctx(context);
    LOGi("n_ctx = %u", n_ctx);

    for (int nri = 0; nri < nr; nri++) {
        LOGi("Benchmark prompt processing (pp = %d)", pp);

        common_batch_clear(g_batch);

        const int n_tokens = pp;
        for (int i = 0; i < n_tokens; i++) {
            common_batch_add(g_batch, 0, i, {0}, false);
        }

        g_batch.logits[g_batch.n_tokens - 1] = true;
        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp_start = ggml_time_us();
        if (llama_decode(context, g_batch) != 0) {
            LOGe("llama_decode() failed during prompt processing");
        }
        const auto t_pp_end = ggml_time_us();

        // bench text generation

        LOGi("Benchmark text generation (tg = %d)", tg);

        llama_memory_clear(llama_get_memory(context), false);
        const auto t_tg_start = ggml_time_us();
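        // Each of the `pl` parallel tokens goes to its own sequence id; this
        // assumes the context allows that many sequences (n_seq_max, which
        // defaults to 1), so pl > 1 would be rejected by llama_decode().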
        for (int i = 0; i < tg; i++) {
            common_batch_clear(g_batch);
            for (int j = 0; j < pl; j++) {
                common_batch_add(g_batch, 0, i, {j}, true);
            }

            if (llama_decode(context, g_batch) != 0) {
                LOGe("llama_decode() failed during text generation");
            }
        }
        const auto t_tg_end = ggml_time_us();

        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;

        const auto speed_pp = double(pp) / t_pp;
        const auto speed_tg = double(pl * tg) / t_tg;

        pp_avg += speed_pp;
        tg_avg += speed_tg;

        pp_std += speed_pp * speed_pp;
        tg_std += speed_tg * speed_tg;

        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
    }

    llama_free(context);

    pp_avg /= double(nr);
    tg_avg /= double(nr);

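    // pp_std/tg_std hold running sums of squares; the sample standard
    // deviation follows from s = sqrt((sum(x^2) - n * mean^2) / (n - 1)).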
    if (nr > 1) {
        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
    } else {
        pp_std = 0;
        tg_std = 0;
    }

    char model_desc[128];
    llama_model_desc(g_model, model_desc, sizeof(model_desc));

    const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
    const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;

    const auto backend = get_backend();
    std::stringstream result;
    result << std::setprecision(3);
    result << "| model | size | params | backend | test | t/s |\n";
    result << "| --- | --- | --- | --- | --- | --- |\n";
    result << "| " << model_desc << " | " << model_size << " GiB | " << model_n_params << " B | "
           << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
    result << "| " << model_desc << " | " << model_size << " GiB | " << model_n_params << " B | "
           << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
    return env->NewStringUTF(result.str().c_str());
}


/**
 * Completion loop's long-term states:
 * - chat management
 * - position tracking
 */
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";

static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;

static void reset_long_term_states(const bool clear_kv_cache = true) {
    chat_msgs.clear();
    system_prompt_position = 0;
    current_position = 0;

    if (clear_kv_cache) {
        llama_memory_clear(llama_get_memory(g_context), false);
    }
}

/**
 * TODO-hyin: implement sliding-window version as a better alternative
 *
 * Context shifting: discard the older half of the tokens appended after the
 * system prompt:
 * - keep the first [system_prompt_position] tokens (the system prompt)
 * - drop the first half of the remaining (current_position - system_prompt_position) tokens
 * - slide the surviving tokens down so the KV cache positions stay contiguous
 * Returns the number of discarded tokens.
 */
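// Example: with system_prompt_position = 100 and current_position = 900,
// n_discard = 400; positions [100, 500) are removed, [500, 900) slide down
// to [100, 500), and current_position becomes 500.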
static int shift_context() {
    const int n_discard = (current_position - system_prompt_position) / 2;
    LOGi("%s: Discarding %d tokens", __func__, n_discard);
    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
    current_position -= n_discard;
    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
    return n_discard;
}

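// Appends a message to the chat history and returns only the newly formatted
// delta (the text still to be tokenized and decoded), not the full transcript.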
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
    common_chat_msg new_msg;
    new_msg.role = role;
    new_msg.content = content;
    auto formatted = common_chat_format_single(
            g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
    chat_msgs.push_back(new_msg);
    LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
    return formatted;
}

/**
 * Completion loop's short-term states:
 * - stop generation position
 * - token chars caching
 * - current assistant message being generated
 */
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;

static void reset_short_term_states() {
    stop_generation_position = 0;
    cached_token_chars.clear();
    assistant_ss.str("");
}

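/**
 * Decodes `tokens` in BATCH_SIZE chunks starting at `start_pos`, shifting the
 * context whenever the next chunk would not fit into the usable window.
 * Logits are computed only for the very last token when `compute_last_logit`
 * is set (required before sampling). Returns 0 on success, 1 on decode failure.
 */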
static int decode_tokens_in_batches(
        llama_context *context,
        llama_batch &batch,
        const llama_tokens &tokens,
        const llama_pos start_pos,
        const bool compute_last_logit = false) {
    // Process tokens in batches using the global batch
    const int n_tokens = (int) tokens.size();
    LOGd("%s: Decode %d tokens starting at position %d", __func__, n_tokens, start_pos);
    llama_pos pos = start_pos;
    for (int i = 0; i < n_tokens; i += BATCH_SIZE) {
        const int cur_batch_size = std::min(n_tokens - i, BATCH_SIZE);
        common_batch_clear(batch);
        LOGv("%s: Preparing a batch of %d tokens starting at: %d", __func__, cur_batch_size, i);

        // Shift context if the current batch cannot fit into it,
        // adjusting our position cursor by the number of discarded tokens
        if (pos + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
            pos -= shift_context();
        }

        // Add tokens to the batch with their positions
        for (int j = 0; j < cur_batch_size; j++) {
            const llama_token token_id = tokens[i + j];
            const bool want_logit = compute_last_logit && (i + j == n_tokens - 1);
            common_batch_add(batch, token_id, pos + j, {0}, want_logit);
        }

        // Decode this batch
        const int decode_result = llama_decode(context, batch);
        if (decode_result) {
            LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
            return 1;
        }
        pos += cur_batch_size;
    }
    return 0;
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring jsystem_prompt
) {
    // Reset long-term & short-term states
    reset_long_term_states();
    reset_short_term_states();

    // Obtain system prompt from the JNI env
    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
    LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
    std::string formatted_system_prompt(system_prompt);
    env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);

    // Format system prompt if a chat template is available
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, formatted_system_prompt);
    }

    // Tokenize system prompt
    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
                                               has_chat_template, has_chat_template);
    for (auto id: system_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Handle context overflow: the system prompt must fit in its entirety
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) system_tokens.size() > max_batch_size) {
        LOGe("%s: System prompt too long for context! %d tokens, max: %d",
             __func__, (int) system_tokens.size(), max_batch_size);
        return 1;
    }

    // Decode system tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update position
    system_prompt_position = current_position = (int) system_tokens.size();
    return 0;
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring juser_prompt,
        jint n_predict
) {
    // Reset short-term states
    reset_short_term_states();

    // Obtain user prompt from the JNI env
    const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
    LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
    std::string formatted_user_prompt(user_prompt);
    env->ReleaseStringUTFChars(juser_prompt, user_prompt);

    // Format user prompt if a chat template is available
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        formatted_user_prompt = chat_add_and_format(ROLE_USER, formatted_user_prompt);
    }

    // Tokenize formatted user prompt
    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
    for (auto id: user_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Ensure the user prompt doesn't exceed the context size, truncating if necessary
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) user_tokens.size() > max_batch_size) {
        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
        user_tokens.resize(max_batch_size);
        LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
    }

    // Decode user tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update positions: generation stops n_predict tokens after the prompt
    current_position += (int) user_tokens.size();
    stop_generation_position = current_position + n_predict;
    return 0;
}

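/**
 * Returns true iff `string` is a complete, well-formed UTF-8 sequence.
 * Token pieces can end mid-way through a multi-byte character, in which case
 * this returns false and the caller keeps buffering bytes (cached_token_chars)
 * until the sequence completes. Note: this is a structural check only;
 * overlong encodings are not rejected, which is acceptable here since the
 * input comes from the model's detokenizer.
 */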
static bool is_valid_utf8(const char *string) {
    if (!string) { return true; }

    const auto *bytes = (const unsigned char *) string;
    int num;

    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
            num = 1;
        } else if ((*bytes & 0xE0) == 0xC0) {
            // U+0080 to U+07FF
            num = 2;
        } else if ((*bytes & 0xF0) == 0xE0) {
            // U+0800 to U+FFFF
            num = 3;
        } else if ((*bytes & 0xF8) == 0xF0) {
            // U+10000 to U+10FFFF
            num = 4;
        } else {
            return false;
        }

        bytes += 1;
        for (int i = 1; i < num; ++i) {
            if ((*bytes & 0xC0) != 0x80) {
                return false;
            }
            bytes += 1;
        }
    }
    return true;
}

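/**
 * Samples, decodes, and returns the next token's text:
 *  - a non-empty String with the next UTF-8-complete piece,
 *  - an empty String while a multi-byte character is still being buffered,
 *  - null when generation stops (EOG token, stop position, or decode error).
 */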
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
        JNIEnv *env,
        jobject /*unused*/
) {
    // Infinite text generation via context shifting
    if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
        LOGw("%s: Context full! Shifting...", __func__);
        shift_context();
    }

    // Stop if we have reached the marked position
    if (current_position >= stop_generation_position) {
        LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
        return nullptr;
    }

    // Sample the next token
    const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
    common_sampler_accept(g_sampler, new_token_id, true);

    // Populate the batch with the new token, then decode
    common_batch_clear(g_batch);
    common_batch_add(g_batch, new_token_id, current_position, {0}, true);
    if (llama_decode(g_context, g_batch) != 0) {
        LOGe("%s: llama_decode() failed for generated token", __func__);
        return nullptr;
    }

    // Update position
    current_position++;

    // Stop if the new token is EOG
    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
        LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
        return nullptr;
    }

    // If not EOG, convert to text
    auto new_token_chars = common_token_to_piece(g_context, new_token_id);
    cached_token_chars += new_token_chars;

    // Create and return a valid UTF-8 Java string
    jstring result = nullptr;
    if (is_valid_utf8(cached_token_chars.c_str())) {
        result = env->NewStringUTF(cached_token_chars.c_str());
        LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());

        assistant_ss << cached_token_chars;
        cached_token_chars.clear();
    } else {
        LOGv("id: %d,\tappend to cache", new_token_id);
        result = env->NewStringUTF("");
    }
    return result;
}


extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
    // Reset long-term & short-term states
    reset_long_term_states();
    reset_short_term_states();

    // Free up resources and null out the globals to avoid dangling pointers
    common_sampler_free(g_sampler);
    g_sampler = nullptr;
    g_chat_templates.reset();
    llama_batch_free(g_batch);
    llama_free(g_context);
    g_context = nullptr;
    llama_model_free(g_model);
    g_model = nullptr;
}

extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv * /*unused*/, jobject /*unused*/) {
    llama_backend_free();
}