#include <android/log.h>
#include <jni.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
#include <unistd.h>
#include <sampling.h>

#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"

template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
    std::ostringstream str;
    for (size_t i = 0; i < values.size(); i++) {
        str << values[i];
        if (i < values.size() - 1) { str << delim; }
    }
    return str.str();
}

/**
 * llama.cpp resources: context, model, batch and sampler
 */
constexpr int   N_THREADS_MIN           = 2;
constexpr int   N_THREADS_MAX           = 4;
constexpr int   N_THREADS_HEADROOM      = 2;

constexpr int   DEFAULT_CONTEXT_SIZE    = 8192;
constexpr int   OVERFLOW_HEADROOM       = 4;
constexpr int   BATCH_SIZE              = 512;
constexpr float DEFAULT_SAMPLER_TEMP    = 0.3f;

static llama_model                      * g_model          = nullptr;
static llama_context                    * g_context        = nullptr;
static llama_batch                        g_batch;
static common_chat_templates_ptr          g_chat_templates;
static common_sampler                   * g_sampler        = nullptr;
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
    // Set llama log handler to Android
    llama_log_set(aichat_android_log_callback, nullptr);

    // Load all CPU backend variants
    const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, nullptr);
    LOGi("Loading backends from %s", path_to_backend);
    ggml_backend_load_all_from_path(path_to_backend);
    env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);

    // Initialize backends
    llama_backend_init();
    LOGi("Backends initialized; log handler set.");
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
    llama_model_params model_params = llama_model_default_params();

    const auto *model_path = env->GetStringUTFChars(jmodel_path, nullptr);
    LOGd("%s: Loading model from: \n%s\n", __func__, model_path);

    auto *model = llama_model_load_from_file(model_path, model_params);
    env->ReleaseStringUTFChars(jmodel_path, model_path);
    if (!model) {
        return 1;
    }
    g_model = model;
    return 0;
}

static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
    if (!model) {
        LOGe("%s: model cannot be null", __func__);
        return nullptr;
    }

    // Multi-threading setup
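    // Heuristic: use the online core count minus N_THREADS_HEADROOM (leaving headroom for the
    // UI and the rest of the system), clamped to [N_THREADS_MIN, N_THREADS_MAX].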
    const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
                                                            (int) sysconf(_SC_NPROCESSORS_ONLN) -
                                                            N_THREADS_HEADROOM));
    LOGi("%s: Using %d threads", __func__, n_threads);

    // Context parameters setup
    llama_context_params ctx_params = llama_context_default_params();
    const int trained_context_size = llama_model_n_ctx_train(model);
    if (n_ctx > trained_context_size) {
        LOGw("%s: Requested context size %d exceeds the model's trained context size %d; using it anyway...",
             __func__, n_ctx, trained_context_size);
    }
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = BATCH_SIZE;
    ctx_params.n_ubatch = BATCH_SIZE;
    ctx_params.n_threads = n_threads;
    ctx_params.n_threads_batch = n_threads;
    auto *context = llama_init_from_model(model, ctx_params);
    if (context == nullptr) {
        LOGe("%s: llama_init_from_model() returned null", __func__);
    }
    return context;
}

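// Builds a common_sampler with llama.cpp's default sampling chain, overriding only the temperature.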
static common_sampler *new_sampler(float temp) {
    common_params_sampling sparams;
    sparams.temp = temp;
    return common_sampler_init(g_model, sparams);
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
    auto *context = init_context(g_model);
    if (!context) { return 1; }
    g_context = context;
    g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
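    // llama_batch_init(n_tokens_max, embd, n_seq_max): embd = 0 means the batch carries token ids
    // (not embeddings); n_seq_max = 1 since completion runs on a single sequence.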
    g_chat_templates = common_chat_templates_init(g_model, "");
    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
    return 0;
}

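// Returns a comma-separated list of the registered non-CPU ggml backends, or "CPU" if none are present.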
static std::string get_backend() {
    std::vector<std::string> backends;
    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
        auto *reg = ggml_backend_reg_get(i);
        std::string name = ggml_backend_reg_name(reg);
        if (name != "CPU") {
            backends.push_back(name);
        }
    }
    return backends.empty() ? "CPU" : join(backends, ",");
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
    return env->NewStringUTF(llama_print_system_info());
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/,
                                                            jint pp, jint tg, jint pl, jint nr) {
    auto *context = init_context(g_model, pp);
    if (!context) {
        const auto *const err_msg = "Failed to init_context! Bench aborted.";
        LOGe("%s", err_msg);
        return env->NewStringUTF(err_msg);
    }
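
    // Each repetition decodes pp dummy tokens in a single batch (prompt processing), then runs
    // tg decode steps of pl parallel sequences (text generation), accumulating tokens/second.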

    auto pp_avg = 0.0;
    auto tg_avg = 0.0;
    auto pp_std = 0.0;
    auto tg_std = 0.0;

    const uint32_t n_ctx = llama_n_ctx(context);
    LOGi("n_ctx = %u", n_ctx);

    for (int nri = 0; nri < nr; nri++) {
        LOGi("Benchmark prompt processing (pp = %d)", pp);

        common_batch_clear(g_batch);

        const int n_tokens = pp;
        for (int i = 0; i < n_tokens; i++) {
            common_batch_add(g_batch, 0, i, {0}, false);
        }

        g_batch.logits[g_batch.n_tokens - 1] = true;
        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp_start = ggml_time_us();
        if (llama_decode(context, g_batch) != 0) {
            LOGe("llama_decode() failed during prompt processing");
        }
        const auto t_pp_end = ggml_time_us();

        // bench text generation

        LOGi("Benchmark text generation (tg = %d)", tg);

        llama_memory_clear(llama_get_memory(context), false);
        const auto t_tg_start = ggml_time_us();
        for (int i = 0; i < tg; i++) {
            common_batch_clear(g_batch);
            for (int j = 0; j < pl; j++) {
                common_batch_add(g_batch, 0, i, {j}, true);
            }

            if (llama_decode(context, g_batch) != 0) {
                LOGe("llama_decode() failed during text generation");
            }
        }
        const auto t_tg_end = ggml_time_us();

        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;

        const auto speed_pp = double(pp) / t_pp;
        const auto speed_tg = double(pl * tg) / t_tg;

        pp_avg += speed_pp;
        tg_avg += speed_tg;

        pp_std += speed_pp * speed_pp;
        tg_std += speed_tg * speed_tg;

        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
    }

    llama_free(context);

    pp_avg /= double(nr);
    tg_avg /= double(nr);

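    // Sample standard deviation from the running sums of squares:
    // sqrt((sum_of_squares - n * mean^2) / (n - 1)), guarded against nr == 1.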
    if (nr > 1) {
        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
    } else {
        pp_std = 0;
        tg_std = 0;
    }

    char model_desc[128];
    llama_model_desc(g_model, model_desc, sizeof(model_desc));

    const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
    const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;

    const auto backend = get_backend();
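    // Report the results as a markdown table, in the style of llama.cpp's llama-bench output.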
    std::stringstream result;
    result << std::setprecision(3);
    result << "| model | size | params | backend | test | t/s |\n";
    result << "| --- | --- | --- | --- | --- | --- |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
           << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
           << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
    return env->NewStringUTF(result.str().c_str());
}


/**
 * Completion loop's long-term states:
 * - chat management
 * - position tracking
 */
constexpr const char *ROLE_SYSTEM       = "system";
constexpr const char *ROLE_USER         = "user";
constexpr const char *ROLE_ASSISTANT    = "assistant";

static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;

static void reset_long_term_states(const bool clear_kv_cache = true) {
    chat_msgs.clear();
    system_prompt_position = 0;
    current_position = 0;

    if (clear_kv_cache)
        llama_memory_clear(llama_get_memory(g_context), false);
}

/**
 * TODO-hyin: implement sliding-window version as a better alternative
 *
 * Context shifting: discard the older half of the tokens appended after the system prompt.
 * - keep the first [system_prompt_position] tokens (the system prompt)
 * - drop the older half of the (current_position - system_prompt_position) tokens that follow
 * - shift the positions of the remaining tokens down so generation can continue
 */
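// Illustrative example: with system_prompt_position = 100 and current_position = 900, n_discard = 400;
// KV-cache entries at positions [100, 500) are removed and entries at [500, 900) shift down to [100, 500).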
static void shift_context() {
    const int n_discard = (current_position - system_prompt_position) / 2;
    LOGi("%s: Discarding %d tokens", __func__, n_discard);
    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
    current_position -= n_discard;
    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}

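// Formats [content] for [role] with the model's chat template (relative to the prior messages),
// appends the message to the in-memory history, and returns the formatted text to be tokenized.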
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
    common_chat_msg new_msg;
    new_msg.role = role;
    new_msg.content = content;
    auto formatted = common_chat_format_single(
            g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
    chat_msgs.push_back(new_msg);
    LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
    return formatted;
}

/**
 * Completion loop's short-term states:
 * - stop generation position
 * - token chars caching
 * - current assistant message being generated
 */
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;

static void reset_short_term_states() {
    stop_generation_position = 0;
    cached_token_chars.clear();
    assistant_ss.str("");
}

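// Decodes [tokens] into [context] starting at position [start_pos], splitting the work into
// chunks of at most BATCH_SIZE and shifting the context when a chunk would not fit.
// Returns 0 on success, non-zero on llama_decode failure.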
static int decode_tokens_in_batches(
        llama_context *context,
        llama_batch &batch,
        const llama_tokens &tokens,
        const llama_pos start_pos,
        const bool compute_last_logit = false) {
    // Process tokens in batches using the global batch
    LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
    llama_pos pos = start_pos;
    for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
        common_batch_clear(batch);
        LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);

        // Shift context if the current batch cannot fit into the context
        if (pos + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
            const llama_pos before = current_position;
            shift_context();
            pos -= before - current_position;  // keep batch positions consistent with the shifted cache
        }

        // Add tokens to the batch with proper positions
        for (int j = 0; j < cur_batch_size; j++) {
            const llama_token token_id = tokens[i + j];
            const bool want_logit = compute_last_logit && (i + j == (int) tokens.size() - 1);
            common_batch_add(batch, token_id, pos + j, {0}, want_logit);
        }
        pos += cur_batch_size;

        // Decode this batch
        const int decode_result = llama_decode(context, batch);
        if (decode_result) {
            LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
            return 1;
        }
    }
    return 0;
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring jsystem_prompt
) {
    // Reset long-term & short-term states
    reset_long_term_states();
    reset_short_term_states();

    // Obtain the system prompt from the JNI environment
    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
    LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
    std::string formatted_system_prompt(system_prompt);
    env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);

    // Format system prompt if applicable
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, formatted_system_prompt);
    }

    // Tokenize system prompt
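    // Both add_special and parse_special are tied to has_chat_template, so special tokens emitted
    // by the chat template are tokenized as such.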
    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
                                               has_chat_template, has_chat_template);
    for (auto id: system_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Handle context overflow
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) system_tokens.size() > max_batch_size) {
        LOGe("%s: System prompt too long for context! %d tokens, max: %d",
             __func__, (int) system_tokens.size(), max_batch_size);
        return 1;
    }

    // Decode system tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update position
    system_prompt_position = current_position = (int) system_tokens.size();
    return 0;
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring juser_prompt,
        jint n_predict
) {
    // Reset short-term states
    reset_short_term_states();

    // Obtain the user prompt from the JNI environment
    const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
    LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
    std::string formatted_user_prompt(user_prompt);
    env->ReleaseStringUTFChars(juser_prompt, user_prompt);

    // Format user prompt if applicable
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        formatted_user_prompt = chat_add_and_format(ROLE_USER, formatted_user_prompt);
    }

    // Tokenize the formatted user prompt
    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
    for (auto id: user_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Ensure the user prompt doesn't exceed the context size, truncating if necessary.
    const int user_prompt_size = (int) user_tokens.size();
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if (user_prompt_size > max_batch_size) {
        const int skipped_tokens = user_prompt_size - max_batch_size;
        user_tokens.resize(max_batch_size);
        LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
    }

    // Decode user tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update position and mark where generation must stop (prompt end + n_predict tokens)
    current_position += (int) user_tokens.size();
    stop_generation_position = current_position + n_predict;
    return 0;
}

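// Returns true only if [string] is a sequence of complete, well-formed UTF-8 code points.
// A token piece may end in the middle of a multi-byte character, in which case it stays in
// cached_token_chars until a later token completes it (env->NewStringUTF requires valid UTF-8).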
static bool is_valid_utf8(const char *string) {
    if (!string) { return true; }

    const auto *bytes = (const unsigned char *) string;
    int num;

    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
            num = 1;
        } else if ((*bytes & 0xE0) == 0xC0) {
            // U+0080 to U+07FF
            num = 2;
        } else if ((*bytes & 0xF0) == 0xE0) {
            // U+0800 to U+FFFF
            num = 3;
        } else if ((*bytes & 0xF8) == 0xF0) {
            // U+10000 to U+10FFFF
            num = 4;
        } else {
            return false;
        }

        bytes += 1;
        for (int i = 1; i < num; ++i) {
            if ((*bytes & 0xC0) != 0x80) {
                return false;
            }
            bytes += 1;
        }
    }
    return true;
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
        JNIEnv *env,
        jobject /*unused*/
) {
    // Infinite text generation via context shifting
    if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
        LOGw("%s: Context full! Shifting...", __func__);
        shift_context();
    }

    // Stop if reaching the marked position
    if (current_position >= stop_generation_position) {
        LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
        return nullptr;
    }

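    // common_sampler_sample(..., -1) samples from the logits of the last token decoded in the
    // previous llama_decode call; common_sampler_accept feeds the chosen token back so that
    // penalties and the sampler's history stay consistent.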
    // Sample next token
    const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
    common_sampler_accept(g_sampler, new_token_id, true);

    // Populate the batch with the new token, then decode
    common_batch_clear(g_batch);
    common_batch_add(g_batch, new_token_id, current_position, {0}, true);
    if (llama_decode(g_context, g_batch) != 0) {
        LOGe("%s: llama_decode() failed for generated token", __func__);
        return nullptr;
    }

    // Update position
    current_position++;

    // Stop if the new token is EOG
    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
        LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
        return nullptr;
    }

    // If not EOG, convert to text
    auto new_token_chars = common_token_to_piece(g_context, new_token_id);
    cached_token_chars += new_token_chars;

    // Create and return a valid UTF-8 Java string
    jstring result = nullptr;
    if (is_valid_utf8(cached_token_chars.c_str())) {
        result = env->NewStringUTF(cached_token_chars.c_str());
        LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());

        assistant_ss << cached_token_chars;
        cached_token_chars.clear();
    } else {
        LOGv("id: %d,\tappend to cache", new_token_id);
        result = env->NewStringUTF("");
    }
    return result;
}


extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
    // Reset long-term & short-term states
    reset_long_term_states();
    reset_short_term_states();

    // Free up resources
    common_sampler_free(g_sampler);
    g_chat_templates.reset();
    llama_batch_free(g_batch);
    llama_free(g_context);
    llama_model_free(g_model);
}

extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) {
    llama_backend_free();
}