1#include "server-context.h"
   2#include "server-common.h"
   3#include "server-http.h"
   4#include "server-task.h"
   5#include "server-queue.h"
   6
   7#include "common.h"
   8#include "llama.h"
   9#include "log.h"
  10#include "sampling.h"
  11#include "speculative.h"
  12#include "mtmd.h"
  13#include "mtmd-helper.h"
  14
  15#include <cstddef>
  16#include <cinttypes>
  17#include <memory>
  18#include <filesystem>
  19
  20// fix problem with std::min and std::max
  21#if defined(_WIN32)
  22#define WIN32_LEAN_AND_MEAN
  23#ifndef NOMINMAX
  24#   define NOMINMAX
  25#endif
  26#include <windows.h>
  27#endif
  28
// JSON type used throughout this file (nlohmann's ordered variant)
using json = nlohmann::ordered_json;

// NOTE(review): presumably the polling interval of the HTTP layer, in seconds - its use is not visible in this part of the file
constexpr int HTTP_POLLING_SECONDS = 1;
  32
  33// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
// lifecycle state of a server_slot
// server_slot::is_processing() treats every state other than SLOT_STATE_IDLE as "busy"
enum slot_state {
    SLOT_STATE_IDLE,
    SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
    SLOT_STATE_STARTED,    // after assigning a task and about to process prompt
    SLOT_STATE_PROCESSING_PROMPT,
    SLOT_STATE_DONE_PROMPT, // required by server_slot::copy_state_to() on the source slot
    SLOT_STATE_GENERATING,
};
  42
// coarse readiness of the server as a whole
enum server_state {
    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
    SERVER_STATE_READY,          // Server is ready and model is loaded
};
  47
  48struct server_slot {
  49    int id;
  50
  51    // TODO: change to unique_ptrs for consistency:
  52    llama_context * ctx = nullptr;
  53
  54    // multimodal
  55    mtmd_context * mctx = nullptr;
  56
  57    common_speculative * spec = nullptr;
  58
  59    // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
  60    //       see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
  61    std::unique_ptr<const server_task> task;
  62    std::unique_ptr<const server_task> task_prev; // used for debugging
  63
  64    // used to determine the slot that has been used the longest
  65    int64_t t_last_used = -1;
  66
  67    // generation props
  68    int32_t n_ctx       = 0;  // context size per slot
  69    int32_t n_keep      = 0;
  70    int32_t n_decoded   = 0;
  71    int32_t n_remaining = -1;
  72    int32_t i_batch     = -1;
  73
  74    int32_t n_prompt_tokens_cache     = 0;
  75    int32_t n_prompt_tokens_processed = 0;
  76
  77    size_t last_nl_pos = 0;
  78
  79    std::string  generated_text;
  80    llama_tokens generated_tokens;
  81
  82    // idx of draft tokens in the main batch
  83    // non-empty if we went to evaluate draft tokens
  84    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
  85    std::vector<int32_t> i_batch_dft;
  86
  87    std::vector<completion_token_output> generated_token_probs;
  88
  89    bool has_next_token = true;
  90    bool has_new_line   = false;
  91    bool truncated      = false;
  92
  93    stop_type stop;
  94
  95    std::string stopping_word;
  96
  97    // state
  98    slot_state state = SLOT_STATE_IDLE;
  99
 100    server_prompt prompt;
 101
 102    void prompt_save(server_prompt_cache & prompt_cache) const {
 103        GGML_ASSERT(prompt.data.size() == 0);
 104
 105        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
 106
 107        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
 108                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
 109
 110        auto * cur = prompt_cache.alloc(prompt, cur_size);
 111        if (cur == nullptr) {
 112            return;
 113        }
 114
 115        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
 116    }
 117
 118    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
 119        bool res = prompt_cache.load(prompt, tokens, ctx, id);
 120        if (!res) {
 121            SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
 122        }
 123
 124        return res;
 125    }
 126
 127    void prompt_clear(bool allow_processing) {
 128        if (!allow_processing) {
 129            GGML_ASSERT(!is_processing());
 130        }
 131
 132        SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
 133
 134        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
 135        prompt.tokens.clear();
 136    }
 137
 138    std::vector<common_adapter_lora_info> lora;
 139    int32_t alora_invocation_start = -1;
 140
 141    // sampling
 142    json json_schema;
 143
 144    common_sampler_ptr smpl;
 145
 146    llama_token  sampled; // in speculative mode, this is the last accepted token
 147    llama_tokens drafted;
 148
 149    // stats
 150    size_t n_sent_text = 0; // number of sent text character
 151
 152    int64_t t_start_process_prompt;
 153    int64_t t_start_generation;
 154
 155    double t_prompt_processing; // ms
 156    double t_token_generation;  // ms
 157
 158    std::function<void(int /* id_slot */)> callback_on_release;
 159
 160    // Speculative decoding stats
 161    int32_t n_draft_total = 0;      // Total draft tokens generated
 162    int32_t n_draft_accepted = 0;   // Draft tokens actually accepted
 163
 164    void reset() {
 165        SLT_DBG(*this, "%s", "\n");
 166
 167        n_prompt_tokens_cache = 0;
 168
 169        last_nl_pos    = 0;
 170        generated_text = "";
 171        has_new_line   = false;
 172        truncated      = false;
 173        stop           = STOP_TYPE_NONE;
 174        stopping_word  = "";
 175        n_sent_text    = 0;
 176
 177        drafted.clear();
 178        i_batch_dft.clear();
 179        generated_tokens.clear();
 180        generated_token_probs.clear();
 181        json_schema = json();
 182
 183        // clear speculative decoding stats
 184        n_draft_total = 0;
 185        n_draft_accepted = 0;
 186
 187        task_prev = std::move(task);
 188        task.reset();
 189
 190        llama_set_sampler(ctx, id, nullptr);
 191
 192        // clear alora start
 193        alora_invocation_start = -1;
 194    }
 195
 196    void init_sampler() const {
 197        common_sampler_reset(smpl.get());
 198
 199        if (!task->need_sampling()) {
 200            return;
 201        }
 202
 203        const int64_t t_start = ggml_time_us();
 204
 205        int n_text = 0;
 206
 207        for (int i = 0; i < (int) prompt.tokens.size(); i++) {
 208            const llama_token id = prompt.tokens[i];
 209
 210            if (id != LLAMA_TOKEN_NULL) {
 211                common_sampler_accept(smpl.get(), id, false);
 212                n_text++;
 213            }
 214        }
 215
 216        SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
 217                (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
 218    }
 219
 220    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
 221    // also we cannot split if the pooling would require any past tokens
 222    bool can_split() const {
 223        GGML_ASSERT(task);
 224
 225        return
 226            !task->need_embd() ||
 227            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
 228    }
 229
 230    bool can_batch_with(server_slot & other_slot) const {
 231        GGML_ASSERT(task);
 232
 233        return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
 234    }
 235
 236    bool has_budget(const common_params & global_params) {
 237        GGML_ASSERT(task);
 238
 239        if (task->params.n_predict == -1 && global_params.n_predict == -1) {
 240            return true; // limitless
 241        }
 242
 243        n_remaining = -1;
 244
 245        if (task->params.n_predict != -1) {
 246            n_remaining = task->params.n_predict - n_decoded;
 247        } else if (global_params.n_predict != -1) {
 248            n_remaining = global_params.n_predict - n_decoded;
 249        }
 250
 251        return n_remaining > 0; // no budget
 252    }
 253
 254    bool is_processing() const {
 255        return state != SLOT_STATE_IDLE;
 256    }
 257
 258    bool can_speculate() const {
 259        return !!spec;
 260    }
 261
 262    void add_token(const completion_token_output & token) {
 263        if (!is_processing()) {
 264            SLT_WRN(*this, "%s", "slot is not processing\n");
 265            return;
 266        }
 267
 268        generated_token_probs.push_back(token);
 269    }
 270
 271    int get_n_draft_max() const {
 272        GGML_ASSERT(task);
 273
 274        if (!can_speculate()) {
 275            return 0;
 276        }
 277
 278        // determine the max draft that fits the current slot state
 279        int n_draft_max = task->params.speculative.n_max;
 280
 281        // note: slot.prompt is not yet expanded with the `id` token sampled above
 282        //       also, need to leave space for 1 extra token to allow context shifts
 283        n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
 284
 285        if (n_remaining > 0) {
 286            n_draft_max = std::min(n_draft_max, n_remaining - 1);
 287        }
 288
 289        SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
 290
 291        if (n_draft_max < task->params.speculative.n_min) {
 292            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
 293            n_draft_max = 0;
 294        }
 295
 296        return n_draft_max;
 297    }
 298
 299    void release() {
 300        if (is_processing()) {
 301            GGML_ASSERT(task);
 302
 303            SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
 304
 305            t_last_used        =  ggml_time_us();
 306            t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
 307
 308            state = SLOT_STATE_IDLE;
 309
 310            // do not keep context of the child slots - the parent's context is enough
 311            if (task->is_child()) {
 312                prompt_clear(false);
 313            }
 314
 315            reset();
 316
 317            callback_on_release(id);
 318        }
 319    }
 320
 321    result_timings get_timings() const {
 322        result_timings timings;
 323        timings.cache_n = n_prompt_tokens_cache;
 324
 325        timings.prompt_n            = n_prompt_tokens_processed;
 326        timings.prompt_ms           = t_prompt_processing;
 327        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
 328        timings.prompt_per_second   = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
 329
 330        timings.predicted_n            = n_decoded;
 331        timings.predicted_ms           = t_token_generation;
 332        timings.predicted_per_token_ms = t_token_generation / n_decoded;
 333        timings.predicted_per_second   = 1e3 / t_token_generation * n_decoded;
 334
 335        // Add speculative metrics
 336        if (n_draft_total > 0) {
 337            timings.draft_n          = n_draft_total;
 338            timings.draft_n_accepted = n_draft_accepted;
 339        }
 340
 341        return timings;
 342    }
 343
 344    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
 345        GGML_ASSERT(task);
 346
 347        size_t stop_pos = std::string::npos;
 348
 349        for (const std::string & word : task->params.antiprompt) {
 350            size_t pos;
 351
 352            if (is_full_stop) {
 353                const size_t tmp      = word.size() + last_token_size;
 354                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
 355
 356                pos = text.find(word, from_pos);
 357            } else {
 358                // otherwise, partial stop
 359                pos = string_find_partial_stop(text, word);
 360            }
 361
 362            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
 363                if (is_full_stop) {
 364                    stop           = STOP_TYPE_WORD;
 365                    stopping_word  = word;
 366                    has_next_token = false;
 367                }
 368                stop_pos = pos;
 369            }
 370        }
 371
 372        return stop_pos;
 373    }
 374
 375    void print_timings() const {
 376        const double t_prompt        =       t_prompt_processing / n_prompt_tokens_processed;
 377        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
 378
 379        const double t_gen        =       t_token_generation / n_decoded;
 380        const double n_gen_second = 1e3 / t_token_generation * n_decoded;
 381
 382        SLT_INF(*this,
 383                "\n"
 384                "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
 385                "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
 386                "      total time = %10.2f ms / %5d tokens\n",
 387                t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
 388                t_token_generation, n_decoded, t_gen, n_gen_second,
 389                t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
 390
 391        if (n_draft_total > 0) {
 392            const float draft_ratio = (float) n_draft_accepted / n_draft_total;
 393            SLT_CNT(*this,
 394                    "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
 395                    draft_ratio, n_draft_accepted, n_draft_total
 396            );
 397        }
 398
 399        common_speculative_print_stats(spec);
 400    }
 401
 402    json to_json(bool only_metrics = false) const {
 403        json res;
 404
 405        res = {
 406            {"id",            id},
 407            {"n_ctx",         n_ctx},
 408            {"speculative",   can_speculate()},
 409            {"is_processing", is_processing()},
 410        };
 411
 412        const auto & ptask = task ? task : task_prev;
 413
 414        if (ptask) {
 415            res["id_task"] = ptask->id;
 416            res["params"] = ptask->params.to_json(only_metrics);
 417            res["next_token"] = {
 418                {
 419                    {"has_next_token", has_next_token},
 420                    {"has_new_line",   has_new_line},
 421                    {"n_remain",       n_remaining},
 422                    {"n_decoded",      n_decoded},
 423                }
 424            };
 425
 426            if (!only_metrics) {
 427                res["prompt"] = ptask->tokens.detokenize(ctx, true);
 428                res["generated"] = generated_text;
 429            }
 430        }
 431
 432        return res;
 433    }
 434
 435    void copy_state_to(server_slot & other) const {
 436        GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
 437
 438        llama_memory_seq_rm(llama_get_memory(ctx), other.id,     -1, -1);
 439        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
 440
 441        other.n_decoded   = n_decoded;
 442        other.n_remaining = n_remaining;
 443        other.i_batch     = i_batch;
 444
 445        other.t_start_process_prompt    = t_start_process_prompt;
 446        other.t_prompt_processing       = t_prompt_processing;
 447        other.n_prompt_tokens_cache     = n_prompt_tokens_cache;
 448        other.n_prompt_tokens_processed = n_prompt_tokens_processed;
 449
 450        other.prompt = prompt.clone();
 451        other.init_sampler();
 452    }
 453};
 454
 455
 456
 457//
 458// server_metrics
 459//
 460
 461struct server_metrics {
 462    int64_t t_start = 0;
 463
 464    uint64_t n_prompt_tokens_processed_total = 0;
 465    uint64_t t_prompt_processing_total       = 0;
 466    uint64_t n_tokens_predicted_total        = 0;
 467    uint64_t t_tokens_generation_total       = 0;
 468
 469    uint64_t n_tokens_max = 0;
 470
 471    uint64_t n_prompt_tokens_processed = 0;
 472    uint64_t t_prompt_processing       = 0;
 473
 474    uint64_t n_tokens_predicted  = 0;
 475    uint64_t t_tokens_generation = 0;
 476
 477    uint64_t n_decode_total     = 0;
 478    uint64_t n_busy_slots_total = 0;
 479
 480    void init() {
 481        t_start = ggml_time_us();
 482    }
 483
 484    void on_prompt_eval(const server_slot & slot) {
 485        n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
 486        n_prompt_tokens_processed       += slot.n_prompt_tokens_processed;
 487        t_prompt_processing             += slot.t_prompt_processing;
 488        t_prompt_processing_total       += slot.t_prompt_processing;
 489
 490        n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
 491    }
 492
 493    void on_prediction(const server_slot & slot) {
 494        n_tokens_predicted_total   += slot.n_decoded;
 495        n_tokens_predicted         += slot.n_decoded;
 496        t_tokens_generation        += slot.t_token_generation;
 497        t_tokens_generation_total  += slot.t_token_generation;
 498    }
 499
 500    void on_decoded(const std::vector<server_slot> & slots) {
 501        n_decode_total++;
 502        for (const auto & slot : slots) {
 503            if (slot.is_processing()) {
 504                n_busy_slots_total++;
 505            }
 506            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
 507        }
 508    }
 509
 510    void reset_bucket() {
 511        n_prompt_tokens_processed = 0;
 512        t_prompt_processing       = 0;
 513        n_tokens_predicted        = 0;
 514        t_tokens_generation       = 0;
 515    }
 516};
 517
 518
 519//
 520// server_context_impl (private implementation)
 521//
 522
 523struct server_context_impl {
 524    friend struct server_context;
 525
 526public:
 527    // only use these pointers outside of this class:
 528    //  - when not in sleeping state
 529    //  - and, with thread-safe APIs (e.g., tokenizer calls)
 530    llama_model * model = nullptr;
 531    mtmd_context * mctx = nullptr;
 532    const llama_vocab * vocab = nullptr;
 533
 534    server_queue    queue_tasks;
 535    server_response queue_results;
 536
    // note: chat_params must not be refreshed upon exiting sleeping state
 538    server_chat_params chat_params;
 539
 540    ~server_context_impl() {
 541        if (!sleeping) {
 542            // destroy() is already called when entering sleeping state
 543            // we don't call it again here to avoid double free
 544            destroy();
 545        }
 546    }
 547
 548private:
 549    // note: accessing these fields outside of this class is not thread-safe
 550    // use server_context methods instead
 551
 552    common_params params_base;
 553
 554    // note: keep these alive - they determine the lifetime of the model, context, etc.
 555    common_init_result_ptr llama_init;
 556
 557    llama_context * ctx = nullptr;
 558
 559    llama_batch batch {};
 560
 561    llama_model_ptr model_dft;
 562
 563    bool add_bos_token  = true;
 564
 565    int32_t n_ctx; // total context for all clients / slots
 566
 567    // slots / clients
 568    std::vector<server_slot> slots;
 569
 570    int slots_debug = 0;
 571
 572    std::unique_ptr<server_prompt_cache> prompt_cache;
 573
 574    server_metrics metrics;
 575
 576    json json_webui_settings = json::object();
 577
 578    // Necessary similarity of prompt for slot selection
 579    float slot_prompt_similarity = 0.0f;
 580
 581    std::string model_name; // name of the loaded model, to be used by API
 582
 583    bool sleeping = false;
 584
 585    void destroy() {
 586        llama_init.reset();
 587        ctx = nullptr;
 588        model = nullptr;
 589
 590        mtmd_free(mctx);
 591        mctx = nullptr;
 592
 593        // Clear any sampling context
 594        for (server_slot & slot : slots) {
 595            common_speculative_free(slot.spec);
 596            slot.spec = nullptr;
 597        }
 598
 599        llama_batch_free(batch);
 600    }
 601
 602    void handle_sleeping_state(bool new_state) {
 603        GGML_ASSERT(sleeping != new_state);
 604        if (new_state) {
 605            SRV_INF("%s", "server is entering sleeping state\n");
 606            destroy();
 607        } else {
 608            SRV_INF("%s", "server is exiting sleeping state\n");
 609            if (!load_model(params_base)) {
 610                GGML_ABORT("failed to reload model after sleeping");
 611            }
 612        }
 613        sleeping = new_state;
 614    }
 615
 616    // load the model and initialize llama_context
 617    // this may also be called to resume from sleeping state
 618    bool load_model(const common_params & params) {
 619        bool is_resume = sleeping;
 620
 621        SRV_INF("loading model '%s'\n", params.model.path.c_str());
 622
 623        params_base = params;
 624
 625        llama_init = common_init_from_params(params_base);
 626
 627        model = llama_init->model();
 628        ctx   = llama_init->context();
 629
 630        if (model == nullptr) {
 631            SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
 632            return false;
 633        }
 634
 635        vocab = llama_model_get_vocab(model);
 636
 637        n_ctx = llama_n_ctx(ctx);
 638
 639        add_bos_token = llama_vocab_get_add_bos(vocab);
 640
 641        if (params_base.speculative.has_dft()) {
 642            SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
 643
 644            const auto & params_spec = params_base.speculative;
 645
 646            auto params_dft = params_base;
 647
 648            params_dft.n_parallel   = 1;
 649            params_dft.n_ctx        = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
 650            params_dft.n_batch      = llama_n_ctx_seq(ctx);
 651            params_dft.devices      = params_spec.devices;
 652            params_dft.model        = params_spec.mparams_dft;
 653            params_dft.n_gpu_layers = params_spec.n_gpu_layers;
 654            params_dft.cache_type_k = params_spec.cache_type_k;
 655            params_dft.cache_type_v = params_spec.cache_type_v;
 656
 657            if (params_spec.cpuparams.n_threads > 0) {
 658                params_dft.cpuparams.n_threads       = params_spec.cpuparams.n_threads;
 659                params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
 660            }
 661
 662            params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
 663
 664            auto mparams_dft = common_model_params_to_llama(params_dft);
 665
 666            model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
 667            if (model_dft == nullptr) {
 668                SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
 669                return false;
 670            }
 671
 672            params_base.speculative.model_dft = model_dft.get();
 673            params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
 674        }
 675
 676        std::string & mmproj_path = params_base.mmproj.path;
 677        if (!mmproj_path.empty()) {
 678            if (!is_resume) {
 679                mtmd_helper_log_set(common_log_default_callback, nullptr);
 680            }
 681
 682            mtmd_context_params mparams = mtmd_context_params_default();
 683
 684            mparams.use_gpu          = params_base.mmproj_use_gpu;
 685            mparams.print_timings    = false;
 686            mparams.n_threads        = params_base.cpuparams.n_threads;
 687            mparams.flash_attn_type  = params_base.flash_attn_type;
 688            mparams.warmup           = params_base.warmup;
 689            mparams.image_min_tokens = params_base.image_min_tokens;
 690            mparams.image_max_tokens = params_base.image_max_tokens;
 691
 692            mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
 693            if (mctx == nullptr) {
 694                SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
 695                return false;
 696            }
 697            SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
 698
 699            if (params_base.ctx_shift) {
 700                params_base.ctx_shift = false;
 701                SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
 702            }
 703
 704            if (params_base.n_cache_reuse) {
 705                params_base.n_cache_reuse = 0;
 706                SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
 707            }
 708
 709            if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
 710                params_base.speculative.type =  COMMON_SPECULATIVE_TYPE_NONE;
 711                SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
 712            }
 713        }
 714
 715        if (!llama_memory_can_shift(llama_get_memory(ctx))) {
 716            if (params_base.ctx_shift) {
 717                params_base.ctx_shift = false;
 718                SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
 719            }
 720
 721            if (params_base.n_cache_reuse) {
 722                params_base.n_cache_reuse = 0;
 723                SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
 724            }
 725        }
 726
 727        // Necessary similarity of prompt for slot selection
 728        slot_prompt_similarity = params_base.slot_prompt_similarity;
 729
 730        // setup slots
 731        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 732
 733        const int n_ctx_train = llama_model_n_ctx_train(model);
 734
 735        int n_ctx_slot = llama_n_ctx_seq(ctx);
 736        if (n_ctx_slot > n_ctx_train) {
 737            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
 738            n_ctx_slot = n_ctx_train;
 739        }
 740
 741        slots.clear();
 742
 743        const bool can_spec = common_speculative_is_compat(ctx);
 744        if (!can_spec) {
 745            SRV_WRN("%s", "speculative decoding not supported by this context\n");
 746        }
 747
 748        // initialize slots
 749        for (int i = 0; i < params_base.n_parallel; i++) {
 750            server_slot slot;
 751
 752            slot.id    = i;
 753            slot.ctx   = ctx;
 754            slot.n_ctx = n_ctx_slot;
 755
 756            slot.mctx                   = mctx;
 757            slot.prompt.tokens.has_mtmd = mctx != nullptr;
 758
 759            // try speculative decoding
 760            if (can_spec) {
 761                slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
 762                if (slot.spec) {
 763                    if (mctx) {
 764                        SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
 765                        return false;
 766                    }
 767                    SLT_INF(slot, "%s", "speculative decoding context initialized\n");
 768                } else {
 769                    SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
 770                }
 771            }
 772
 773            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 774
 775            slot.callback_on_release = [this](int id_slot) {
 776                queue_tasks.pop_deferred_task(id_slot);
 777            };
 778
 779            slot.reset();
 780
 781            slots.push_back(std::move(slot));
 782        }
 783
 784        {
 785            const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
 786            slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
 787
 788            if (slots_debug) {
 789                SRV_WRN("slots debug = %d\n", slots_debug);
 790            }
 791        }
 792
 793        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
 794        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
 795        {
 796            const int32_t n_batch = llama_n_batch(ctx);
 797            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
 798        }
 799
 800        if (params_base.cache_ram_mib != 0) {
 801            if (params_base.cache_ram_mib < 0) {
 802                SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
 803            } else {
 804                SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
 805            }
 806            SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
 807
 808            prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
 809        } else {
 810            SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
 811        }
 812        SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
 813
 814        if (!params_base.model_alias.empty()) {
 815            // user explicitly specified model name
 816            model_name = params_base.model_alias;
 817        } else if (!params_base.model.name.empty()) {
 818            // use model name in registry format (for models in cache)
 819            model_name = params_base.model.name;
 820        } else {
 821            // fallback: derive model name from file name
 822            auto model_path = std::filesystem::path(params_base.model.path);
 823            model_name = model_path.filename().string();
 824        }
 825
 826        if (!is_resume) {
 827            return init();
 828        }
 829
 830        return true;
 831    }
 832
 833    // unlike load_model(), this is only called once during initialization
 834    bool init() {
 835        GGML_ASSERT(ctx != nullptr);
 836        GGML_ASSERT(model != nullptr);
 837        GGML_ASSERT(!sleeping);
 838
 839        // wiring up server queues
 840        queue_tasks.on_new_task([this](server_task && task) {
 841            process_single_task(std::move(task));
 842        });
 843        queue_tasks.on_update_slots([this]() {
 844            update_slots();
 845        });
 846        queue_tasks.on_sleeping_state([this](bool sleeping) {
 847            handle_sleeping_state(sleeping);
 848        });
 849
 850        metrics.init();
 851
 852        // populate webui settings
 853        {
 854            if (!params_base.webui_config_json.empty()) {
 855                try {
 856                    json_webui_settings = json::parse(params_base.webui_config_json);
 857                } catch (const std::exception & e) {
 858                    SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
 859                    return false;
 860                }
 861            }
 862        }
 863
 864        // populate chat template params
 865        {
 866            common_chat_templates_ptr chat_templates;
 867
 868            try {
 869                chat_templates = common_chat_templates_init(model, params_base.chat_template);
 870
 871                LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
 872                    common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
 873
 874            } catch (const std::exception & e) {
 875                SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
 876                SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__);
 877                SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__);
 878                return false;
 879            }
 880
 881            // thinking is enabled if:
 882            // 1. It's not explicitly disabled (reasoning_budget == 0)
 883            // 2. The chat template supports it
 884            const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
 885            SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
 886
 887            chat_params = {
 888                /* use_jinja             */ params_base.use_jinja,
 889                /* prefill_assistant     */ params_base.prefill_assistant,
 890                /* reasoning_format      */ params_base.reasoning_format,
 891                /* chat_template_kwargs  */ params_base.default_template_kwargs,
 892                /* tmpls                 */ std::move(chat_templates),
 893                /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
 894                /* allow_audio           */ mctx ? mtmd_support_audio (mctx) : false,
 895                /* enable_thinking       */ enable_thinking,
 896                /* media_path            */ params_base.media_path,
 897            };
 898        }
 899
 900        return true;
 901    }
 902
 903    server_slot * get_slot_by_id(int id_slot) {
 904        // note: allow id_slot to be out of bounds (wrap around)
 905        id_slot = id_slot % slots.size();
 906
 907        for (server_slot & slot : slots) {
 908            if (slot.id == id_slot) {
 909                return &slot;
 910            }
 911        }
 912
 913        return nullptr;
 914    }
 915
 916    server_slot * get_available_slot(const server_task & task) {
 917        server_slot * ret = nullptr;
 918
 919        bool update_cache = false;
 920
 921        // find the slot that has at least n% prompt similarity
 922        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
 923            float sim_best = 0;
 924
 925            for (server_slot & slot : slots) {
 926                // skip the slot if it is not available
 927                if (slot.is_processing()) {
 928                    continue;
 929                }
 930
 931                const auto & tokens = slot.prompt.tokens;
 932
 933                // skip the slot if it does not contains cached tokens
 934                if (tokens.empty()) {
 935                    continue;
 936                }
 937
 938                // fraction of the Longest Common Prefix length with respect to the input prompt length
 939                const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
 940
 941                // select the current slot if the criteria match
 942                if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
 943                    sim_best = sim_cur;
 944
 945                    ret = &slot;
 946                }
 947            }
 948
 949            if (ret != nullptr) {
 950                const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
 951
 952                SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
 953                        sim_best, slot_prompt_similarity, f_keep);
 954
 955                // if we are about to lose a large portion of the existing context - save it in the prompt cache
 956                if (f_keep < 0.5f) {
 957                    update_cache = true;
 958                }
 959            }
 960        }
 961
 962        // find the slot that has been least recently used
 963        if (ret == nullptr) {
 964            int64_t t_last = -1;
 965
 966            for (server_slot & slot : slots) {
 967                // skip the slot if it is not available
 968                if (slot.is_processing()) {
 969                    continue;
 970                }
 971
 972                // select the current slot if the criteria match
 973                if (!ret || slot.t_last_used <= t_last) {
 974                    t_last = slot.t_last_used;
 975                    ret = &slot;
 976                }
 977            }
 978
 979            if (ret != nullptr) {
 980                SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
 981
 982                update_cache = true;
 983            }
 984        }
 985
 986        if (ret) {
 987            const auto & tokens = ret->prompt.tokens;
 988
 989            update_cache = update_cache && prompt_cache;
 990
 991            // cache prompts only for completion tasks
 992            update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
 993
 994            // don't update the cache if the slot's context is empty
 995            update_cache = update_cache && tokens.size() > 0;
 996
 997            // TODO: mtmd does not support prompt cache
 998            update_cache = update_cache && (ret->mctx == nullptr);
 999
1000            if (update_cache) {
1001                SRV_WRN("%s", "updating prompt cache\n");
1002
1003                const int64_t t_start = ggml_time_us();
1004
1005                ret->prompt_save(*prompt_cache);
1006
1007                if (!ret->prompt_load(*prompt_cache, task.tokens)) {
1008                    ret->prompt_clear(false);
1009                }
1010
1011                prompt_cache->update();
1012
1013                SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
1014            }
1015        }
1016
1017        return ret;
1018    }
1019
1020    // return true if at least one slot has been cleared
1021    // TODO: improve logic
1022    //       - smarter decision which slot to clear (LRU or longest prompt?)
1023    //       - move slot to level 2 cache instead of removing?
1024    //       - instead of purging, try to store and resume later?
1025    bool try_clear_idle_slots() {
1026        bool res = false;
1027
1028        if (!params_base.kv_unified) {
1029            return res;
1030        }
1031
1032        for (auto & slot : slots) {
1033            if (slot.is_processing()) {
1034                continue;
1035            }
1036
1037            if (slot.prompt.n_tokens() > 0) {
1038                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
1039
1040                slot.prompt_clear(false);
1041
1042                res = true;
1043
1044                // clear slots one by one
1045                break;
1046            }
1047        }
1048
1049        return res;
1050    }
1051
1052    std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
1053        std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
1054        for (size_t i = 0; i < output.size(); ++i) {
1055            auto it = config.find(i);
1056            if (it != config.end()) {
1057                output[i].scale = it->second;
1058            } else {
1059                output[i].scale = 0.0f;
1060            }
1061        }
1062        return output;
1063    }
1064
1065    bool launch_slot_with_task(server_slot & slot, server_task && task) {
1066        // process per-request lora adapters
1067        if (!task.params.lora.empty()) {
1068            auto task_loras = construct_lora_list(task.params.lora);
1069            if (!are_lora_equal(task_loras, slot.lora)) {
1070                // if lora has changed, check to see if the cache should be cleared
1071                if (lora_should_clear_cache(slot.lora, task_loras)) {
1072                    SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
1073                    slot.prompt.tokens.clear();
1074                } else {
1075                    SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task_loras.size());
1076                }
1077                slot.lora = task_loras;
1078            }
1079        } else {
1080            slot.lora = params_base.lora_adapters;
1081        }
1082
1083        // if using alora, make sure it's only a single one requested and active
1084        size_t alora_invocation_start = task.tokens.size();
1085        if (lora_all_alora(slot.lora)) {
1086            const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
1087            // TODO: This will error out if a user requests two aloras, but only
1088            // provides the activation string for one. We could, instead search
1089            // for all requested alora activation strings and then either keep
1090            // only the last one, or reject if multiple are found.
1091            if (enabled_ids.size() != 1) {
1092                send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
1093                return false;
1094            }
1095            const auto & lora = slot.lora[enabled_ids[0]].ptr;
1096
1097            // get the pointer and count for the invocation tokens
1098            const uint64_t      n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
1099            const llama_token * invocation_tokens   = llama_adapter_get_alora_invocation_tokens  (lora);
1100
1101            // scan backwards through the prompt tokens to find the last
1102            // occurrence of the invocation sequence
1103            int match_idx = static_cast<int>(n_invocation_tokens) - 1;
1104            for (int i = task.tokens.size() - 1; i >= 0; --i) {
1105                // the token in this position matches the next token to find in
1106                // the invocation sequence
1107                if (task.tokens[i] == invocation_tokens[match_idx]) {
1108                    // if it's a full match, we've found the start
1109                    if (match_idx == 0) {
1110                        alora_invocation_start = i;
1111                        break;
1112                    }
1113                    // otherwise, check the next token in the sequence
1114                    --match_idx;
1115                } else {
1116                    // no match in this position, so start looking over again
1117                    match_idx = static_cast<int>(n_invocation_tokens) - 1;
1118                }
1119            }
1120
1121            // if the activation string is not found, disable the alora
1122            if (alora_invocation_start == task.tokens.size()) {
1123                SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
1124                slot.lora[enabled_ids[0]].scale = 0.0f;
1125            } else {
1126                SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
1127                slot.alora_invocation_start = alora_invocation_start;
1128            }
1129        }
1130
1131        if (!task.tokens.validate(ctx)) {
1132            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
1133            return false;
1134        }
1135
1136        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
1137
1138        // initialize samplers
1139        if (task.need_sampling()) {
1140            slot.smpl.reset(common_sampler_init(model, task.params.sampling));
1141
1142            if (slot.smpl == nullptr) {
1143                // for now, the only error that may happen here is invalid grammar
1144                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
1145                return false;
1146            }
1147
1148            const bool need_logits = task.params.sampling.n_probs > 0;
1149
1150            bool backend_sampling = true;
1151
1152            backend_sampling &= task.params.sampling.backend_sampling;
1153
1154            // TODO: speculative decoding requires multiple samples per batch - not supported yet
1155            backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
1156
1157            // TODO: getting post/pre sampling logits is not yet supported with backend sampling
1158            backend_sampling &= !need_logits;
1159
1160            // TODO: tmp until backend sampling is fully implemented
1161            if (backend_sampling) {
1162                llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
1163            } else {
1164                llama_set_sampler(ctx, slot.id, nullptr);
1165            }
1166
1167            SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
1168        } else {
1169            slot.smpl.reset();
1170        }
1171
1172        slot.task = std::make_unique<const server_task>(std::move(task));
1173
1174        slot.state = slot.task->is_child()
1175            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
1176            : SLOT_STATE_STARTED;
1177
1178        SLT_INF(slot, "processing task, is_child = %d\n", slot.task->is_child());
1179        return true;
1180    }
1181
    // handle one newly sampled token for a slot:
    //   - append its text to the slot's generated text
    //   - detect stop strings and decide which part of the text can be sent
    //   - stream a partial response if the task requested streaming
    //   - evaluate the stop conditions (context capacity, budget, indentation, time limit, EOG)
    // result.text_to_send is rewritten to the text that is actually sent for this token
    // returns slot.has_next_token, i.e. true if generation should continue
    bool process_token(completion_token_output & result, server_slot & slot) {
        // remember which tokens were sampled - used for repetition penalties during sampling
        const std::string token_str = result.text_to_send;
        slot.sampled = result.tok;

        slot.generated_text += token_str;
        if (slot.task->params.return_tokens) {
            slot.generated_tokens.push_back(result.tok);
        }
        slot.has_next_token = true;

        // check if there is incomplete UTF-8 character at the end
        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

        // search stop word and delete it
        // (skipped while the text ends with an incomplete UTF-8 character)
        if (!incomplete) {
            // start of the text that has not been sent yet
            size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

            const std::string str_test = slot.generated_text.substr(pos);
            bool send_text = true;

            // NOTE(review): the last argument presumably selects full (true) vs partial (false) stop string matching - confirm against find_stopping_strings
            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
            if (stop_pos != std::string::npos) {
                // cut the generated text at the position reported by the search
                slot.generated_text.erase(
                    slot.generated_text.begin() + pos + stop_pos,
                    slot.generated_text.end());
                pos = std::min(slot.n_sent_text, slot.generated_text.size());
            } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) {
                // hold the text back if the second search reports a position
                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
                send_text = stop_pos == std::string::npos;
            }

            // check if there is any token to predict
            if (send_text) {
                // do not send the stop word in the response
                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                slot.n_sent_text += result.text_to_send.size();
                // add the token to slot queue and cache
            } else {
                result.text_to_send = "";
            }

            slot.add_token(result);
            if (slot.task->params.stream) {
                send_partial_response(slot, result, false);
            }
        }

        // keep generating while the text ends with an incomplete UTF-8 character
        if (incomplete) {
            slot.has_next_token = true;
        }

        // if context shifting is disabled, make sure that we don't run out of context
        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
            slot.truncated      = true;
            slot.stop           = STOP_TYPE_LIMIT;
            slot.has_next_token = false;

            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
        }

        // check the limits
        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
            slot.stop           = STOP_TYPE_LIMIT;
            slot.has_next_token = false;

            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
        }

        // indentation limit: only checked once a new line has been seen (has_new_line is set further below)
        if (slot.has_new_line) {
            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
            if (slot.task->params.n_indent > 0) {
                // check the current indentation
                // TODO: improve by not doing it more than once for each new line
                if (slot.last_nl_pos > 0) {
                    size_t pos = slot.last_nl_pos;

                    // count the spaces/tabs starting at last_nl_pos
                    int n_indent = 0;
                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
                        n_indent++;
                        pos++;
                    }

                    // only decide once a non-whitespace character follows the indentation
                    if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
                        slot.stop           = STOP_TYPE_LIMIT;
                        slot.has_next_token = false;

                        // cut the last line
                        slot.generated_text.erase(pos, std::string::npos);

                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
                    }
                }

                // find the next new line
                {
                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);

                    if (pos != std::string::npos) {
                        // remember the position right after the new line
                        slot.last_nl_pos = pos + 1;
                    }
                }
            }
        }

        // check if there is a new line in the generated text
        if (result.text_to_send.find('\n') != std::string::npos) {
            slot.has_new_line = true;

            // if we have seen a new line, we stop after a certain time limit, but only upon another new line
            if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
                slot.stop           = STOP_TYPE_LIMIT;
                slot.has_next_token = false;

                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
            }
        }

        // an end-of-generation token always stops the slot (and overrides any stop type set above)
        if (llama_vocab_is_eog(vocab, result.tok)) {
            slot.stop           = STOP_TYPE_EOS;
            slot.has_next_token = false;

            SLT_DBG(slot, "%s", "stopped by EOS\n");
        }

        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

        return slot.has_next_token; // continue
    }
1312
1313    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
1314        const size_t n_probs_request = slot.task->params.sampling.n_probs;
1315
1316        if (post_sampling) {
1317            const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
1318            const size_t max_probs = cur_p->size;
1319            const size_t n_probs = std::min(max_probs, n_probs_request);
1320
1321            // set probability for sampled token
1322            for (size_t i = 0; i < max_probs; i++) {
1323                if (cur_p->data[i].id == result.tok) {
1324                    result.prob = cur_p->data[i].p;
1325                    break;
1326                }
1327            }
1328
1329            // set probability for top n_probs tokens
1330            result.probs.reserve(n_probs);
1331            for (size_t i = 0; i < n_probs; i++) {
1332                result.probs.push_back({
1333                    cur_p->data[i].id,
1334                    common_token_to_piece(ctx, cur_p->data[i].id, special),
1335                    cur_p->data[i].p
1336                });
1337            }
1338        } else {
1339            // TODO: optimize this with min-p optimization
1340            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
1341            const size_t max_probs = cur.size();
1342            const size_t n_probs = std::min(max_probs, n_probs_request);
1343
1344            // set probability for sampled token
1345            for (size_t i = 0; i < max_probs; i++) {
1346                // set probability for sampled token
1347                if (cur[i].id == result.tok) {
1348                    result.prob = cur[i].p;
1349                    break;
1350                }
1351            }
1352
1353            // set probability for top n_probs tokens
1354            result.probs.reserve(n_probs);
1355            for (size_t i = 0; i < n_probs; i++) {
1356                result.probs.push_back({
1357                    cur[i].id,
1358                    common_token_to_piece(ctx, cur[i].id, special),
1359                    cur[i].p
1360                });
1361            }
1362        }
1363    }
1364
1365    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1366        send_error(task.id, error, type);
1367    }
1368
1369    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1370        send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
1371    }
1372
1373    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
1374        SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
1375
1376        if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
1377            GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
1378        }
1379
1380        auto res = std::make_unique<server_task_result_error>();
1381        res->id              = id_task;
1382        res->err_type        = type;
1383        res->err_msg         = error;
1384        res->n_prompt_tokens = n_prompt_tokens;
1385        res->n_ctx           = n_ctx;
1386
1387        queue_results.send(std::move(res));
1388    }
1389
1390    // if multimodal is enabled, send an error and return false
1391    bool check_no_mtmd(const int id_task) {
1392        if (mctx) {
1393            send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
1394            return false;
1395        }
1396        return true;
1397    }
1398
1399    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
1400        auto res = std::make_unique<server_task_result_cmpl_partial>();
1401
1402        res->id    = slot.task->id;
1403        res->index = slot.task->index;
1404
1405        if (is_progress) {
1406            res->is_progress        = true;
1407            res->progress.total     = slot.task->n_tokens();
1408            res->progress.cache     = slot.n_prompt_tokens_cache;
1409            res->progress.processed = slot.prompt.tokens.size();
1410            res->progress.time_ms   = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
1411        } else {
1412            res->content = tkn.text_to_send;
1413            res->tokens  = { tkn.tok };
1414        }
1415
1416        res->n_decoded           = slot.n_decoded;
1417        res->n_prompt_tokens     = slot.task->n_tokens();
1418        res->post_sampling_probs = slot.task->params.post_sampling_probs;
1419
1420        res->verbose           = slot.task->params.verbose;
1421        res->res_type          = slot.task->params.res_type;
1422        res->oaicompat_model   = slot.task->params.oaicompat_model;
1423        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
1424
1425        // populate res.probs_output
1426        if (slot.task->params.sampling.n_probs > 0) {
1427            res->prob_output = tkn; // copy the token probs
1428        }
1429
1430        // populate timings if this is final response or timings_per_token is enabled
1431        if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
1432            res->timings = slot.get_timings();
1433        }
1434
1435        queue_results.send(std::move(res));
1436    }
1437
1438    void send_final_response(server_slot & slot) {
1439        auto res = std::make_unique<server_task_result_cmpl_final>();
1440
1441        res->id      = slot.task->id;
1442        res->id_slot = slot.id;
1443
1444        res->index           = slot.task->index;
1445        // in stream mode, content and tokens are already in last partial chunk
1446        if (slot.task->params.stream) {
1447            res->content     = "";
1448            res->tokens      = llama_tokens{};
1449        } else {
1450            res->content     = std::move(slot.generated_text);
1451            res->tokens      = std::move(slot.generated_tokens);
1452        }
1453        res->timings         = slot.get_timings();
1454        res->prompt          = slot.task->tokens.detokenize(ctx, true);
1455        res->response_fields = std::move(slot.task->params.response_fields);
1456
1457        res->truncated           = slot.truncated;
1458        res->n_decoded           = slot.n_decoded;
1459        res->n_prompt_tokens     = slot.task->n_tokens();
1460        res->n_tokens_cached     = slot.prompt.n_tokens();
1461        res->has_new_line        = slot.has_new_line;
1462        res->stopping_word       = slot.stopping_word;
1463        res->stop                = slot.stop;
1464        res->post_sampling_probs = slot.task->params.post_sampling_probs;
1465
1466        res->verbose           = slot.task->params.verbose;
1467        res->stream            = slot.task->params.stream;
1468        res->include_usage     = slot.task->params.include_usage;
1469        res->res_type          = slot.task->params.res_type;
1470        res->oaicompat_model   = slot.task->params.oaicompat_model;
1471        res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
1472
1473        // populate res.probs_output
1474        if (slot.task->params.sampling.n_probs > 0) {
1475            if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
1476                const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
1477
1478                size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
1479                res->probs_output = std::vector<completion_token_output>(
1480                        slot.generated_token_probs.begin(),
1481                        slot.generated_token_probs.end() - safe_offset);
1482            } else {
1483                res->probs_output = std::vector<completion_token_output>(
1484                        slot.generated_token_probs.begin(),
1485                        slot.generated_token_probs.end());
1486            }
1487        }
1488
1489        res->generation_params = slot.task->params; // copy the parameters
1490
1491        queue_results.send(std::move(res));
1492    }
1493
1494    void send_embedding(const server_slot & slot, const llama_batch & batch) {
1495        auto res = std::make_unique<server_task_result_embd>();
1496        res->id        = slot.task->id;
1497        res->index     = slot.task->index;
1498        res->n_tokens  = slot.task->n_tokens();
1499        res->res_type  = slot.task->params.res_type;
1500
1501        const int n_embd_out = llama_model_n_embd_out(model);
1502
1503        std::vector<float> embd_res(n_embd_out, 0.0f);
1504
1505        for (int i = 0; i < batch.n_tokens; ++i) {
1506            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1507                continue;
1508            }
1509
1510            const float * embd = nullptr;
1511            if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
1512                embd = llama_get_embeddings_ith(ctx, i);
1513            } else {
1514                embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1515            }
1516
1517            if (embd == nullptr) {
1518                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1519
1520                res->embedding.push_back(std::vector<float>(n_embd_out, 0.0f));
1521                continue;
1522            }
1523
1524            // normalize only when there is pooling
1525            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
1526                common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
1527                res->embedding.push_back(embd_res);
1528                break;
1529            }
1530
1531            res->embedding.emplace_back(embd, embd + n_embd_out);
1532        }
1533
1534        SLT_DBG(slot, "%s", "sending embeddings\n");
1535
1536        queue_results.send(std::move(res));
1537    }
1538
1539    void send_rerank(const server_slot & slot, const llama_batch & batch) {
1540        auto res = std::make_unique<server_task_result_rerank>();
1541        res->id       = slot.task->id;
1542        res->index    = slot.task->index;
1543        res->n_tokens = slot.task->n_tokens();
1544
1545        for (int i = 0; i < batch.n_tokens; ++i) {
1546            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1547                continue;
1548            }
1549
1550            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1551            if (embd == NULL) {
1552                embd = llama_get_embeddings_ith(ctx, i);
1553            }
1554
1555            if (embd == NULL) {
1556                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1557
1558                res->score = -1e6;
1559                continue;
1560            }
1561
1562            res->score = embd[0];
1563        }
1564
1565        SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
1566
1567        queue_results.send(std::move(res));
1568    }
1569
1570    //
1571    // Functions to process the task
1572    //
1573
1574    // tokenize the input if it's set by CLI, return false on error
1575    bool tokenize_cli_input(server_task & task) {
1576        try {
1577            auto & prompt = task.cli_prompt;
1578            if (mctx != nullptr) {
1579                task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files);
1580            } else {
1581                task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]);
1582            }
1583            task.cli_prompt.clear();
1584            task.cli_files.clear();
1585        } catch (const std::exception & e) {
1586            send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
1587            return false;
1588        }
1589        return true;
1590    }
1591
1592    std::vector<server_slot *> get_free_slots(size_t n_slots_needed, int exclude_id_slot) {
1593        std::vector<server_slot *> free_slots;
1594        for (auto & slot : slots) {
1595            if (!slot.is_processing() && slot.id != exclude_id_slot) {
1596                free_slots.push_back(&slot);
1597            }
1598            if (free_slots.size() >= n_slots_needed) {
1599                break;
1600            }
1601        }
1602        return free_slots;
1603    }
1604
1605    // launch multiple slots for parent + child tasks
1606    bool launch_slots_with_parent_task(server_slot & parent_slot, std::vector<server_slot *> & child_slots, server_task && parent_task) {
1607        GGML_ASSERT(!parent_slot.is_processing());
1608        GGML_ASSERT(parent_task.is_parent());
1609        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
1610
1611        int id_parent = parent_task.id;
1612
1613        SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
1614
1615        // to be called in case of failure to release all launched slots
1616        auto release_slots = [this, id_parent]() {
1617            for (auto & slot : slots) {
1618                if (slot.is_processing() && (
1619                        slot.task->id == id_parent ||
1620                        slot.task->id_parent == id_parent
1621                )) {
1622                    slot.release();
1623                }
1624            }
1625        };
1626
1627        // launch all child tasks first
1628        size_t idx = 0;
1629        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
1630        for (auto * slot : child_slots) {
1631            int id_child = parent_task.child_tasks[idx].id;
1632            if (!launch_slot_with_task(*slot, std::move(parent_task.child_tasks[idx]))) {
1633                SRV_ERR("failed to launch slot with child task, id_task = %d\n", id_child);
1634                release_slots();
1635                return false;
1636            }
1637            idx++;
1638        }
1639
1640        // finally, launch the parent task
1641        if (!launch_slot_with_task(parent_slot, std::move(parent_task))) {
1642            SRV_ERR("failed to launch slot with task, id_task = %d\n", id_parent);
1643            release_slots();
1644            return false;
1645        }
1646
1647        return true;
1648    }
1649
1650    void process_single_task(server_task && task) {
1651        switch (task.type) {
1652            case SERVER_TASK_TYPE_COMPLETION:
1653            case SERVER_TASK_TYPE_INFILL:
1654            case SERVER_TASK_TYPE_EMBEDDING:
1655            case SERVER_TASK_TYPE_RERANK:
1656                {
1657                    // special case: if input is provided via CLI, tokenize it first
1658                    // otherwise, no need to tokenize as it's already done inside the HTTP thread
1659                    if (task.cli) {
1660                        if (!tokenize_cli_input(task)) {
1661                            break;
1662                        }
1663                    }
1664
1665                    const int id_slot = task.id_slot;
1666                    const int id_task = task.id;
1667
1668                    server_slot * slot = id_slot != -1
1669                                            ? get_slot_by_id(id_slot)
1670                                            : get_available_slot(task);
1671
1672                    //
1673                    // slot scheduling logic
1674                    //
1675
1676                    if (slot == nullptr) {
1677                        // if no slot is available, we defer this task for processing later
1678                        SRV_DBG("no slot is available, defer task, id_task = %d\n", id_task);
1679                        queue_tasks.defer(std::move(task));
1680                        break;
1681                    }
1682
1683                    if (slot->is_processing()) {
1684                        // if requested slot is unavailable, we defer this task for processing later
1685                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", id_task);
1686                        queue_tasks.defer(std::move(task));
1687                        break;
1688                    }
1689
1690                    if (task.is_parent()) {
1691                        // try getting free slots for all child tasks
1692                        size_t n_child_tasks = task.child_tasks.size();
1693                        std::vector<server_slot *> child_slots = get_free_slots(n_child_tasks, slot->id);
1694                        if (child_slots.size() < n_child_tasks) {
1695                            SRV_DBG("not enough free slots for child tasks, n_free = %zu, n_children = %zu, defer task, id_task = %d\n", child_slots.size(), n_child_tasks, id_task);
1696                            queue_tasks.defer(std::move(task));
1697                            break;
1698                        }
1699                        if (!launch_slots_with_parent_task(*slot, child_slots, std::move(task))) {
1700                            SRV_ERR("failed to launch slot with parent task, id_task = %d\n", id_task);
1701                            break; // drop the task
1702                        }
1703                    } else if (!launch_slot_with_task(*slot, std::move(task))) {
1704                        SRV_ERR("failed to launch slot with task, id_task = %d\n", id_task);
1705                        break; // drop the task
1706                    }
1707                } break;
1708            case SERVER_TASK_TYPE_CANCEL:
1709                {
1710                    // release slot linked with the task id
1711                    for (auto & slot : slots) {
1712                        if (slot.task && slot.task->id == task.id_target) {
1713                            slot.release();
1714                            break;
1715                        }
1716                    }
1717                } break;
1718            case SERVER_TASK_TYPE_NEXT_RESPONSE:
1719                {
1720                    // do nothing
1721                } break;
1722            case SERVER_TASK_TYPE_METRICS:
1723                {
1724                    json slots_data = json::array();
1725
1726                    int n_idle_slots       = 0;
1727                    int n_processing_slots = 0;
1728
1729                    for (server_slot & slot : slots) {
1730                        json slot_data = slot.to_json(slots_debug == 0);
1731
1732                        if (slot.is_processing()) {
1733                            n_processing_slots++;
1734                        } else {
1735                            n_idle_slots++;
1736                        }
1737
1738                        slots_data.push_back(slot_data);
1739                    }
1740                    SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
1741
1742                    auto res = std::make_unique<server_task_result_metrics>();
1743                    res->id                  = task.id;
1744                    res->slots_data          = std::move(slots_data);
1745                    res->n_idle_slots        = n_idle_slots;
1746                    res->n_processing_slots  = n_processing_slots;
1747                    res->n_tasks_deferred    = queue_tasks.queue_tasks_deferred_size();
1748                    res->t_start             = metrics.t_start;
1749
1750                    res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
1751                    res->t_prompt_processing_total       = metrics.t_prompt_processing_total;
1752                    res->n_tokens_predicted_total        = metrics.n_tokens_predicted_total;
1753                    res->t_tokens_generation_total       = metrics.t_tokens_generation_total;
1754
1755                    res->n_tokens_max = metrics.n_tokens_max;
1756
1757                    res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
1758                    res->t_prompt_processing       = metrics.t_prompt_processing;
1759                    res->n_tokens_predicted        = metrics.n_tokens_predicted;
1760                    res->t_tokens_generation       = metrics.t_tokens_generation;
1761
1762                    res->n_decode_total          = metrics.n_decode_total;
1763                    res->n_busy_slots_total      = metrics.n_busy_slots_total;
1764
1765                    if (task.metrics_reset_bucket) {
1766                        metrics.reset_bucket();
1767                    }
1768                    queue_results.send(std::move(res));
1769                } break;
1770            case SERVER_TASK_TYPE_SLOT_SAVE:
1771                {
1772                    if (!check_no_mtmd(task.id)) {
1773                        break;
1774                    }
1775
1776                    const int id_slot = task.slot_action.id_slot;
1777                    server_slot * slot = get_slot_by_id(id_slot);
1778                    if (slot == nullptr) {
1779                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1780                        break;
1781                    }
1782                    if (slot->is_processing()) {
1783                        // if requested slot is unavailable, we defer this task for processing later
1784                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1785                        queue_tasks.defer(std::move(task));
1786                        break;
1787                    }
1788
1789                    const size_t token_count = slot->prompt.tokens.size();
1790                    const int64_t t_start = ggml_time_us();
1791
1792                    std::string filename = task.slot_action.filename;
1793                    std::string filepath = task.slot_action.filepath;
1794
1795                    const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
1796                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
1797
1798                    const int64_t t_end = ggml_time_us();
1799                    const double t_save_ms = (t_end - t_start) / 1000.0;
1800
1801                    auto res = std::make_unique<server_task_result_slot_save_load>();
1802                    res->id       = task.id;
1803                    res->id_slot  = id_slot;
1804                    res->filename = filename;
1805                    res->is_save  = true;
1806                    res->n_tokens = token_count;
1807                    res->n_bytes  = nwrite;
1808                    res->t_ms     = t_save_ms;
1809                    queue_results.send(std::move(res));
1810                } break;
1811            case SERVER_TASK_TYPE_SLOT_RESTORE:
1812                {
1813                    if (!check_no_mtmd(task.id)) break;
1814                    const int id_slot = task.slot_action.id_slot;
1815                    server_slot * slot = get_slot_by_id(id_slot);
1816                    if (slot == nullptr) {
1817                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1818                        break;
1819                    }
1820                    if (slot->is_processing()) {
1821                        // if requested slot is unavailable, we defer this task for processing later
1822                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1823                        queue_tasks.defer(std::move(task));
1824                        break;
1825                    }
1826
1827                    const int64_t t_start = ggml_time_us();
1828
1829                    std::string filename = task.slot_action.filename;
1830                    std::string filepath = task.slot_action.filepath;
1831
1832                    llama_tokens tokens;
1833                    tokens.resize(slot->n_ctx);
1834                    size_t token_count = 0;
1835                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
1836                    if (nread == 0) {
1837                        slot->prompt.tokens.clear(); // KV may already been invalidated?
1838                        send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
1839                        break;
1840                    }
1841                    tokens.resize(token_count);
1842                    slot->prompt.tokens.clear();
1843                    slot->prompt.tokens.insert(tokens);
1844
1845                    const int64_t t_end = ggml_time_us();
1846                    const double t_restore_ms = (t_end - t_start) / 1000.0;
1847
1848                    auto res = std::make_unique<server_task_result_slot_save_load>();
1849                    res->id       = task.id;
1850                    res->id_slot  = id_slot;
1851                    res->filename = filename;
1852                    res->is_save  = false;
1853                    res->n_tokens = token_count;
1854                    res->n_bytes  = nread;
1855                    res->t_ms     = t_restore_ms;
1856                    queue_results.send(std::move(res));
1857                } break;
1858            case SERVER_TASK_TYPE_SLOT_ERASE:
1859                {
1860                    if (!check_no_mtmd(task.id)) {
1861                        break;
1862                    }
1863                    const int id_slot = task.slot_action.id_slot;
1864                    server_slot * slot = get_slot_by_id(id_slot);
1865                    if (slot == nullptr) {
1866                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1867                        break;
1868                    }
1869                    if (slot->is_processing()) {
1870                        // if requested slot is unavailable, we defer this task for processing later
1871                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1872                        queue_tasks.defer(std::move(task));
1873                        break;
1874                    }
1875
1876                    // Erase token cache
1877                    const size_t n_erased = slot->prompt.tokens.size();
1878
1879                    slot->prompt_clear(false);
1880
1881                    auto res = std::make_unique<server_task_result_slot_erase>();
1882                    res->id       = task.id;
1883                    res->id_slot  = id_slot;
1884                    res->n_erased = n_erased;
1885                    queue_results.send(std::move(res));
1886                } break;
1887            case SERVER_TASK_TYPE_GET_LORA:
1888                {
1889                    // TODO @ngxson : make lora_adapters a dedicated member of server_context
1890                    auto & loras = params_base.lora_adapters;
1891                    auto res = std::make_unique<server_task_result_get_lora>();
1892                    res->id = task.id;
1893                    for (size_t i = 0; i < loras.size(); ++i) {
1894                        auto & lora = loras[i];
1895                        std::string alora_invocation_string = "";
1896                        const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
1897                        llama_tokens alora_invocation_tokens;
1898                        if (n_alora_tokens) {
1899                            const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
1900                            for (uint64_t j = 0; j < n_alora_tokens; ++j) {
1901                                alora_invocation_string += common_token_to_piece(vocab, alora_tokens[j]);
1902                                alora_invocation_tokens.push_back(alora_tokens[j]);
1903                            }
1904                        }
1905                        res->loras.push_back(server_task_result_get_lora::lora{
1906                            lora,
1907                            alora_invocation_string,
1908                            alora_invocation_tokens,
1909                        });
1910                    }
1911                    queue_results.send(std::move(res));
1912                } break;
1913            case SERVER_TASK_TYPE_SET_LORA:
1914                {
1915                    auto new_loras = construct_lora_list(task.set_lora);
1916                    // logging
1917                    for (size_t i = 0; i < new_loras.size(); ++i) {
1918                        SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
1919                    }
1920                    // TODO @ngxson : make lora_adapters a dedicated member of server_context
1921                    params_base.lora_adapters = new_loras;
1922                    auto res = std::make_unique<server_task_result_apply_lora>();
1923                    res->id = task.id;
1924                    queue_results.send(std::move(res));
1925                } break;
1926        }
1927    }
1928
1929    void update_slots() {
1930        // check if all slots are idle
1931        {
1932            bool all_idle = true;
1933
1934            for (auto & slot : slots) {
1935                if (slot.is_processing()) {
1936                    all_idle = false;
1937                    break;
1938                }
1939            }
1940
1941            if (all_idle) {
1942                SRV_INF("%s", "all slots are idle\n");
1943
1944                return;
1945            }
1946        }
1947
1948        {
1949            SRV_DBG("%s", "posting NEXT_RESPONSE\n");
1950
1951            server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
1952            task.id = queue_tasks.get_new_id();
1953            queue_tasks.post(std::move(task));
1954        }
1955
1956        // apply context-shift if needed
1957        // TODO: simplify and improve
1958        for (server_slot & slot : slots) {
1959            if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
1960                if (!params_base.ctx_shift) {
1961                    // this check is redundant (for good)
1962                    // we should never get here, because generation should already stopped in process_token()
1963                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1964                    slot.release();
1965                    continue;
1966                }
1967
1968                if (mctx) {
1969                    // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
1970                    // we don't support ctx_shift because an image chunk may contains multiple tokens
1971                    GGML_ABORT("not supported by multimodal");
1972                }
1973
1974                if (slot.task->is_parent() || slot.task->is_child()) {
1975                    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
1976                    slot.release();
1977                    continue;
1978                }
1979
1980                // Shift context
1981                int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
1982
1983                if (add_bos_token) {
1984                    n_keep += 1;
1985                }
1986
1987                n_keep = std::min(slot.n_ctx - 4, n_keep);
1988
1989                const int n_left    = slot.prompt.n_tokens() - n_keep;
1990                const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
1991
1992                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1993
1994                llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep            , n_keep + n_discard);
1995                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
1996
1997                // add generated tokens to cache
1998                // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
1999                {
2000                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
2001
2002                    llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
2003                    for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
2004                        new_tokens[i - n_discard] = new_tokens[i];
2005                    }
2006
2007                    new_tokens.resize(slot.prompt.tokens.size() - n_discard);
2008
2009                    slot.prompt.tokens.clear();
2010                    slot.prompt.tokens.insert(new_tokens);
2011                }
2012
2013                slot.truncated = true;
2014            }
2015        }
2016
2017        // start populating the batch for this iteration
2018        common_batch_clear(batch);
2019
2020        // track if given slot can be batched with slots already in the batch
2021        server_slot * slot_batched = nullptr;
2022
2023        auto accept_special_token = [&](server_slot & slot, llama_token token) {
2024            return params_base.special ||
2025                slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
2026        };
2027
2028        // first, add sampled tokens from any ongoing sequences
2029        for (auto & slot : slots) {
2030            if (slot.state != SLOT_STATE_GENERATING) {
2031                continue;
2032            }
2033
2034            // check if we can batch this slot with the previous one
2035            if (!slot_batched) {
2036                slot_batched = &slot;
2037            } else if (!slot_batched->can_batch_with(slot)) {
2038                continue;
2039            }
2040
2041            // generate draft tokens in speculative decoding mode
2042            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
2043            //       perform the speculative drafting for all sequences at the same time in a single batch
2044            const int n_draft_max = slot.get_n_draft_max();
2045            if (n_draft_max > 0) {
2046                if (mctx) {
2047                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
2048                    GGML_ABORT("not supported by multimodal");
2049                }
2050
2051                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
2052
2053                const auto & params_spec = slot.task->params.speculative;
2054
2055                llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
2056
2057                if (draft.size() > (size_t) n_draft_max) {
2058                    SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
2059                    draft.resize(n_draft_max);
2060                }
2061
2062                // add the sampled token to the batch
2063                slot.i_batch_dft.push_back(batch.n_tokens);
2064                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
2065                slot.prompt.tokens.push_back(slot.sampled);
2066
2067                if (slot.task->params.speculative.n_min > (int) draft.size()) {
2068                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
2069                    // fallback to normal decoding
2070                    slot.i_batch = slot.i_batch_dft[0];
2071                    slot.drafted.clear();
2072                    slot.i_batch_dft.clear();
2073                } else {
2074                    // keep track of total number of drafted tokens tested
2075                    slot.n_draft_total += draft.size();
2076
2077                    // add all drafted tokens to the batch
2078                    for (size_t i = 0; i < draft.size(); i++) {
2079                        slot.i_batch_dft.push_back(batch.n_tokens);
2080                        common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
2081                        slot.prompt.tokens.push_back(draft[i]);
2082                    }
2083                    slot.drafted = std::move(draft);
2084                }
2085            } else {
2086                // no speculative decoding
2087                slot.i_batch = batch.n_tokens;
2088
2089                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
2090
2091                slot.prompt.tokens.push_back(slot.sampled);
2092
2093                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
2094                        slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
2095            }
2096        }
2097
2098        // process in chunks of params.n_batch
2099        int32_t n_batch  = llama_n_batch(ctx);
2100        int32_t n_ubatch = llama_n_ubatch(ctx);
2101
2102        float  alora_scale       = -1.0f;
2103        size_t alora_disabled_id = 0;
2104
2105        // next, batch any pending prompts without exceeding n_batch
2106        if (params_base.cont_batching || batch.n_tokens == 0) {
2107            for (auto & slot : slots) {
2108                if (!slot.is_processing()) {
2109                    continue;
2110                }
2111
2112                // check if we can batch this slot with the previous one
2113                if (slot_batched && !slot_batched->can_batch_with(slot)) {
2114                    continue;
2115                }
2116
2117                // check if this is a child slot
2118                if (slot.state == SLOT_STATE_WAIT_OTHER) {
2119                    SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
2120                    continue;
2121                }
2122
2123                // this slot still has a prompt to be processed
2124                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
2125                    const auto & input_tokens = slot.task->tokens;
2126
2127                    // TODO: maybe move branch to outside of this loop in the future
2128                    if (slot.state == SLOT_STATE_STARTED) {
2129                        slot.t_start_process_prompt = ggml_time_us();
2130                        slot.t_start_generation = 0;
2131
2132                        slot.state = SLOT_STATE_PROCESSING_PROMPT;
2133
2134                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
2135                                slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
2136
2137                        // print prompt tokens (for debugging)
2138                        /*if (1) {
2139                            // first 16 tokens (avoid flooding logs)
2140                            for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
2141                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
2142                            }
2143                        } else {
2144                            // all
2145                            for (int i = 0; i < (int) input_tokens.size(); i++) {
2146                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
2147                            }
2148                        }*/
2149
2150                        // keep track how many tokens we can reuse from the previous state
2151                        int n_past = 0;
2152
2153                        // empty prompt passed -> release the slot and send empty response
2154                        if (input_tokens.empty()) {
2155                            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
2156
2157                            slot.print_timings();
2158                            send_final_response(slot);
2159                            slot.release();
2160
2161                            continue;
2162                        }
2163
2164                        // TODO: support memory-less logits computation
2165                        if (slot.task->need_logits() && !llama_get_memory(ctx)) {
2166                            send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
2167                            slot.release();
2168                            continue;
2169                        }
2170
2171                        if (!slot.can_split()) {
2172                            if (slot.task->n_tokens() > n_ubatch) {
2173                                send_error(slot,
2174                                           string_format(
2175                                               "input (%d tokens) is too large to process. increase the physical batch "
2176                                               "size (current batch size: %d)",
2177                                               slot.task->n_tokens(), n_ubatch),
2178                                           ERROR_TYPE_SERVER);
2179                                slot.release();
2180                                continue;
2181                            }
2182
2183                            if (slot.task->n_tokens() > slot.n_ctx) {
2184                                send_error(
2185                                    slot,
2186                                    string_format(
2187                                        "input (%d tokens) is larger than the max context size (%d tokens). skipping",
2188                                        slot.task->n_tokens(), slot.n_ctx),
2189                                    ERROR_TYPE_EXCEED_CONTEXT_SIZE);
2190                                slot.release();
2191                                continue;
2192                            }
2193                        } else {
2194                            if (slot.task->n_tokens() >= slot.n_ctx) {
2195                                send_error(slot,
2196                                           string_format("request (%d tokens) exceeds the available context size (%d "
2197                                                         "tokens), try increasing it",
2198                                                         slot.task->n_tokens(), slot.n_ctx),
2199                                           ERROR_TYPE_EXCEED_CONTEXT_SIZE);
2200                                slot.release();
2201                                continue;
2202                            }
2203
2204                            if (slot.task->params.cache_prompt) {
2205                                // reuse any previously computed tokens that are common with the new prompt
2206                                n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
2207
2208                                // if there is an alora invoked, don't cache after the invocation start
2209                                if (slot.alora_invocation_start > 0) {
2210                                    SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
2211                                    n_past = std::min(n_past, slot.alora_invocation_start - 1);
2212                                }
2213
2214                                const auto n_cache_reuse = slot.task->params.n_cache_reuse;
2215
2216                                const bool can_cache_reuse =
2217                                    llama_memory_can_shift(llama_get_memory(ctx)) &&
2218                                    !slot.prompt.tokens.has_mtmd;
2219
2220                                if (!can_cache_reuse && n_cache_reuse > 0) {
2221                                    SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
2222                                }
2223
2224                                // reuse chunks from the cached prompt by shifting their KV cache in the new position
2225                                if (can_cache_reuse && n_cache_reuse > 0) {
2226                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
2227
2228                                    size_t head_c = n_past; // cache
2229                                    size_t head_p = n_past; // current prompt
2230
2231                                    if (mctx) {
2232                                        // we should never reach this
2233                                        GGML_ABORT("not supported by multimodal");
2234                                    }
2235
2236                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
2237
2238                                    while (head_c < slot.prompt.tokens.size() &&
2239                                           head_p < input_tokens.size()) {
2240
2241                                        size_t n_match = 0;
2242                                        while (head_c + n_match < slot.prompt.tokens.size() &&
2243                                               head_p + n_match < input_tokens.size()       &&
2244                                               slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
2245                                            n_match++;
2246                                        }
2247
2248                                        if (n_match >= (size_t) n_cache_reuse) {
2249                                            SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
2250                                            //for (size_t i = head_p; i < head_p + n_match; i++) {
2251                                            //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2252                                            //}
2253
2254                                            const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
2255
2256                                            llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
2257                                            llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
2258
2259                                            for (size_t i = 0; i < n_match; i++) {
2260                                                slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
2261                                                n_past++;
2262                                            }
2263
2264                                            head_c += n_match;
2265                                            head_p += n_match;
2266                                        } else {
2267                                            head_c += 1;
2268                                        }
2269                                    }
2270
2271                                    SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
2272                                }
2273                            } else {
2274                                // if we don't cache the prompt, we have to remove all previous tokens
2275                                n_past = 0;
2276                            }
2277
2278                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
2279                            const auto n_swa = std::max(1, llama_model_n_swa(model));
2280
2281                            // the largest pos_min required for a checkpoint to be useful
2282                            const auto pos_min_thold = std::max(0, n_past - n_swa);
2283
2284                            // note: disallow with mtmd contexts for now
2285                            //       https://github.com/ggml-org/llama.cpp/issues/17043
2286                            if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
2287                                const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
2288                                if (pos_min == -1) {
2289                                    SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
2290                                    GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
2291                                }
2292
2293                                // when the prompt prefix does not match, print the tokens around the mismatch
2294                                // this is useful for debugging prompt caching
2295                                if (slots_debug) {
2296                                    const int np0 = std::max<int>(n_past - 4, 0);
2297                                    const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
2298
2299                                    std::stringstream ss0;
2300                                    std::stringstream ss1;
2301
2302                                    std::stringstream st0;
2303                                    std::stringstream st1;
2304
2305                                    ss0 << "old: ... ";
2306                                    ss1 << "new: ... ";
2307
2308                                    for (int i = np0; i < np1; i++) {
2309                                        if (i == n_past) {
2310                                            ss0 << " | ";
2311                                            ss1 << " | ";
2312                                        }
2313
2314                                        {
2315                                            const auto token = slot.prompt.tokens[i];
2316                                            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
2317                                            ss0 << piece;
2318                                            st0 << std::setw(8) << token;
2319                                        }
2320
2321                                        {
2322                                            const auto token = slot.task->tokens[i];
2323                                            const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
2324                                            ss1 << piece;
2325                                            st1 << std::setw(8) << token;
2326                                        }
2327                                    }
2328
2329                                    SLT_WRN(slot, "%s\n", ss0.str().c_str());
2330                                    SLT_WRN(slot, "%s\n", ss1.str().c_str());
2331
2332                                    SLT_WRN(slot, "%s\n", st0.str().c_str());
2333                                    SLT_WRN(slot, "%s\n", st1.str().c_str());
2334                                }
2335
2336                                if (pos_min > pos_min_thold) {
2337                                    // TODO: support can be added in the future when corresponding vision models get released
2338                                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
2339
2340                                    SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
2341
2342                                    // search for a context checkpoint
2343                                    const auto it = std::find_if(
2344                                        slot.prompt.checkpoints.rbegin(),
2345                                        slot.prompt.checkpoints.rend(),
2346                                        [&](const auto & cur) {
2347                                            // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
2348                                            return cur.pos_min < pos_min_thold;
2349                                        }
2350                                    );
2351
2352                                    bool do_reset = it == slot.prompt.checkpoints.rend();
2353
2354                                    if (!do_reset) {
2355                                        // restore the context checkpoint
2356                                        const size_t checkpoint_size = it->data.size();
2357                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2358
2359                                        if (n != checkpoint_size) {
2360                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
2361                                            do_reset = true;
2362                                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
2363                                        } else {
2364                                            n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
2365                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
2366                                        }
2367                                    }
2368
2369                                    if (do_reset) {
2370                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
2371                                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
2372                                        n_past = 0;
2373                                    }
2374                                }
2375                            }
2376
2377                            {
2378                                // erase any checkpoints with pos_min > pos_min_thold
2379                                for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
2380                                    const auto & cur = *it;
2381                                    if (cur.pos_min > pos_min_thold) {
2382                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
2383                                        it = slot.prompt.checkpoints.erase(it);
2384                                    } else {
2385                                        ++it;
2386                                    }
2387                                }
2388                            }
2389                        }
2390
2391                        // [TAG_PROMPT_LOGITS]
2392                        if (n_past == slot.task->n_tokens() && n_past > 0) {
2393                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
2394                            n_past--;
2395                            SLT_WRN(slot, "n_past was set to %d\n", n_past);
2396                        }
2397
2398                        slot.n_prompt_tokens_cache     = n_past;
2399                        slot.n_prompt_tokens_processed = 0;
2400
2401                        slot.prompt.tokens.keep_first(n_past);
2402
2403                        // send initial 0% progress update if needed
2404                        // this is to signal the client that the request has started processing
2405                        if (slot.task->params.stream && slot.task->params.return_progress) {
2406                            send_partial_response(slot, {}, true);
2407                        }
2408                    }
2409
2410                    if (!slot.can_split()) {
2411                        // cannot fit the prompt in the current batch - will try next iter
2412                        if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
2413                            continue;
2414                        }
2415                    }
2416
2417                    // truncate any tokens that are beyond n_past for this slot
2418                    const llama_pos p0 = slot.prompt.tokens.pos_next();
2419
2420                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
2421
2422                    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
2423                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
2424
2425                        slot.prompt_clear(true);
2426
2427                        // there is no common part left
2428                        slot.n_prompt_tokens_cache = 0;
2429                    }
2430
2431                    // check if we should process the image
2432                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
2433                        // process the image
2434                        size_t n_tokens_out = 0;
2435                        int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
2436                        if (res != 0) {
2437                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
2438                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
2439                            slot.release();
2440                            continue;
2441                        }
2442
2443                        slot.n_prompt_tokens_processed += n_tokens_out;
2444
2445                        // add the image chunk to cache
2446                        {
2447                            const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
2448                            slot.prompt.tokens.push_back(chunk.get()); // copy
2449                        }
2450                    }
2451
2452                    // If using an alora, there may be uncached tokens that come
2453                    // before the invocation sequence. When this happens, the
2454                    // tokens before the invocation sequence need to be
2455                    // processed without the adapter in a separate batch, then
2456                    // the adapter needs to be enabled for the remaining tokens.
2457                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
2458                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
2459                        const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
2460                        GGML_ASSERT(enabled_loras.size() == 1);
2461                        alora_scale = slot.lora[enabled_loras[0]].scale;
2462                        slot.lora[enabled_loras[0]].scale = 0.0f;
2463                        alora_disabled_id = enabled_loras[0];
2464                    }
2465
2466                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
2467
2468                    // make checkpoints only for completion tasks
2469                    do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
2470
2471                    // make a checkpoint of the parts of the memory that cannot be rolled back.
2472                    // checkpoints are created only if:
2473                    // - the model uses SWA and we are not using `swa_full`
2474                    // - the model architecture is marked as recurrent or hybrid
2475                    //
2476                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
2477                    do_checkpoint = do_checkpoint && (
2478                            llama_model_is_recurrent(model) ||
2479                            llama_model_is_hybrid(model) ||
2480                            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
2481                            );
2482
2483                    // add prompt tokens for processing in the current batch
2484                    while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
2485                        // get next token to process
2486                        llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
2487                        if (cur_tok == LLAMA_TOKEN_NULL) {
2488                            break; // end of text chunk
2489                        }
2490
2491                        // if this is an alora request with pre-invocation
2492                        // tokens that are not cached, we need to stop filling
2493                        // this batch at those pre-invocation tokens.
2494                        if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
2495                            SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
2496                            break;
2497                        }
2498
2499                        // embedding requires all tokens in the batch to be output
2500                        common_batch_add(batch,
2501                            cur_tok,
2502                            slot.prompt.tokens.pos_next(),
2503                            { slot.id },
2504                            slot.task->need_embd());
2505                        slot.prompt.tokens.push_back(cur_tok);
2506
2507                        slot.n_prompt_tokens_processed++;
2508
2509                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
2510                        const int n_last = std::min(n_batch, 512);
2511                        if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
2512                            break;
2513                        }
2514                    }
2515
2516                    // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
2517
2518                    SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
2519
2520                    // entire prompt has been processed
2521                    if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
2522                        slot.state = SLOT_STATE_DONE_PROMPT;
2523
2524                        GGML_ASSERT(batch.n_tokens > 0);
2525
2526                        // extract the logits only for the last token
2527                        batch.logits[batch.n_tokens - 1] = true;
2528
2529                        slot.n_decoded = 0;
2530                        slot.i_batch   = batch.n_tokens - 1;
2531
2532                        SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
2533
2534                        slot.init_sampler();
2535
2536                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
2537                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
2538
2539                        // no need for empty or small checkpoints
2540                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
2541
2542                        // no need to create checkpoints that are too close together
2543                        do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
2544
2545                        if (do_checkpoint) {
2546                            while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
2547                                // make room for the new checkpoint, if needed
2548                                const auto & cur = slot.prompt.checkpoints.front();
2549
2550                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
2551                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
2552
2553                                slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
2554                            }
2555
2556                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2557
2558                            auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
2559                                /*.pos_min = */ pos_min,
2560                                /*.pos_max = */ pos_max,
2561                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
2562                            });
2563
2564                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2565
2566                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
2567                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
2568                        }
2569                    }
2570                }
2571
2572                if (!slot_batched) {
2573                    slot_batched = &slot;
2574                }
2575
2576                if (batch.n_tokens >= n_batch) {
2577                    break;
2578                }
2579            }
2580        }
2581
2582        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
2583
2584        if (slot_batched) {
2585            // apply lora, only need to do it once per batch
2586            common_set_adapter_lora(ctx, slot_batched->lora);
2587
2588            // if the lora is temporarily disabled for an alora, re-enable it
2589            // for next time
2590            if (alora_scale > 0.0f) {
2591                SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
2592                slot_batched->lora[alora_disabled_id].scale = alora_scale;
2593            }
2594
2595            llama_set_embeddings(ctx, slot_batched->task->need_embd());
2596        }
2597
2598        if (batch.n_tokens == 0) {
2599            SRV_WRN("%s", "no tokens to decode\n");
2600        }
2601
2602        int32_t i_next = 0;
2603
2604        // process the created batch of tokens
2605        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
2606            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
2607
2608            llama_batch batch_view = {
2609                n_tokens,
2610                batch.token    + i,
2611                nullptr,
2612                batch.pos      + i,
2613                batch.n_seq_id + i,
2614                batch.seq_id   + i,
2615                batch.logits   + i,
2616            };
2617
2618            const int ret = llama_decode(ctx, batch_view);
2619
2620            metrics.on_decoded(slots);
2621
2622            if (ret != 0) {
2623                {
2624                    std::string err;
2625
2626                    if (n_batch == 1 && ret == 1) {
2627                        // TODO: try to terminate only the largest active slot/sequence and continue with the rest
2628                        //       need to remove the tokens from the current batch too
2629                        err = "Context size has been exceeded.";
2630                    }
2631
2632                    if (ret == -1) {
2633                        err = "Invalid input batch.";
2634                    }
2635
2636                    if (ret < -1) {
2637                        // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
2638                        err = "Compute error.";
2639                    }
2640
2641                    // TODO: handle ret == 2 (abort) when we start aborting
2642
2643                    if (!err.empty()) {
2644                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
2645
2646                        for (auto & slot : slots) {
2647                            if (slot.is_processing()) {
2648                                send_error(slot, err);
2649                                slot.release();
2650
2651                                // note: it's complicated to keep track of how much of the current batch has been
2652                                //       processed before the error occurred, so we simply clear the entire context
2653                                slot.prompt_clear(false);
2654                            }
2655                        }
2656
2657                        break;
2658                    }
2659                }
2660
2661                // retry with half the batch size to try to find a free slot in the KV cache
2662                if (!try_clear_idle_slots()) {
2663                    n_batch /= 2;
2664                }
2665
2666                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
2667
2668                continue; // continue loop of n_batch
2669            }
2670
2671            // move the head of the batch forward with the number of tokens we just processed
2672            i_next = i + n_tokens;
2673
2674            // on successful decode, restore the original batch size
2675            n_batch = llama_n_batch(ctx);
2676
2677            // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
2678            for (auto & slot : slots) {
2679                if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
2680                    std::vector<server_slot *> children;
2681                    for (auto & other : slots) {
2682                        if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
2683                            children.push_back(&other);
2684                        }
2685                    }
2686
2687                    // all children slots should already launched by launch_slots_with_parent_task()
2688                    // copy state to the child slots
2689                    for (auto & child : children) {
2690                        SLT_INF(slot, " - copying state to child %d\n", child->id);
2691
2692                        GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
2693
2694                        slot.copy_state_to(*child);
2695                        child->state = SLOT_STATE_DONE_PROMPT;
2696                    }
2697                }
2698            }
2699
2700            for (auto & slot : slots) {
2701                // optionally send prompt processing progress
2702                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
2703                    if (slot.task->params.stream && slot.task->params.return_progress) {
2704                        send_partial_response(slot, {}, true);
2705                    }
2706                }
2707
2708                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2709                    continue; // continue loop of slots
2710                }
2711
2712                if (slot.state == SLOT_STATE_DONE_PROMPT) {
2713                    if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
2714                        // prompt evaluated for embedding
2715                        send_embedding(slot, batch_view);
2716                        slot.release();
2717                        slot.i_batch = -1;
2718                        continue; // continue loop of slots
2719                    }
2720
2721                    if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
2722                        send_rerank(slot, batch_view);
2723                        slot.release();
2724                        slot.i_batch = -1;
2725                        continue; // continue loop of slots
2726                    }
2727
2728                    GGML_ASSERT(slot.task->need_sampling());
2729
2730                    // prompt evaluated for next-token prediction
2731                    slot.state = SLOT_STATE_GENERATING;
2732
2733                    if (slot.can_speculate()) {
2734                        common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
2735                    }
2736                } else if (slot.state != SLOT_STATE_GENERATING) {
2737                    continue; // continue loop of slots
2738                }
2739
2740                if (slot.i_batch_dft.size() > 0) {
2741                    continue; // sample using speculative decoding
2742                }
2743
2744                const int tok_idx = slot.i_batch - i;
2745
2746                llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
2747
2748                slot.i_batch = -1;
2749
2750                common_sampler_accept(slot.smpl.get(), id, true);
2751
2752                // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
2753                const int64_t t_current = ggml_time_us();
2754
2755                slot.n_decoded += 1;
2756
2757                if (slot.n_decoded == 1) {
2758                    slot.t_start_generation = t_current;
2759                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2760                    metrics.on_prompt_eval(slot);
2761                }
2762
2763                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
2764
2765                completion_token_output result;
2766                result.tok          = id;
2767                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
2768                result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
2769
2770                if (slot.task->params.sampling.n_probs > 0) {
2771                    populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
2772                }
2773
2774                if (!process_token(result, slot)) {
2775                    // release slot because of stop condition
2776                    slot.print_timings();
2777                    send_final_response(slot);
2778                    metrics.on_prediction(slot);
2779                    slot.release();
2780
2781                    continue;
2782                }
2783            }
2784
2785            // speculative decoding - main model sample and accept
2786            for (auto & slot : slots) {
2787                if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
2788                    continue;
2789                }
2790
2791                const size_t n_draft = slot.drafted.size();
2792
2793                // the accepted tokens from the speculation
2794                const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
2795                slot.i_batch_dft.clear();
2796                slot.drafted.clear();
2797
2798                const int64_t t_current = ggml_time_us();
2799
2800                slot.n_decoded += ids.size();
2801
2802                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
2803
2804                // update how many tokens out of those tested were accepted
2805                slot.n_draft_accepted += ids.size() - 1;
2806
2807                // inform the speculative decoding about the number of accepted tokens
2808                common_speculative_accept(slot.spec, ids.size() - 1);
2809
2810                // rollback to the state before sampling the draft tokens
2811                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
2812
2813                // add accepted tokens to the prompt
2814                slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
2815                slot.sampled = ids.back(); // last accepted token
2816
2817                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
2818
2819                for (size_t i = 0; i < ids.size(); ++i) {
2820                    completion_token_output result;
2821
2822                    result.tok          = ids[i];
2823                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
2824                    result.prob         = 1.0f; // set later
2825
2826                    // TODO: set result.probs
2827
2828                    if (!process_token(result, slot)) {
2829                        slot.print_timings();
2830                        send_final_response(slot);
2831                        metrics.on_prediction(slot);
2832                        slot.release();
2833
2834                        break;
2835                    }
2836                }
2837
2838                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
2839            }
2840        }
2841
2842        SRV_DBG("%s", "run slots completed\n");
2843    }
2844
2845    int get_slot_n_ctx() {
2846        return slots.back().n_ctx;
2847    }
2848
    // Creates a response reader bound to this context's task queue and result queue,
    // using HTTP_POLLING_SECONDS as its polling interval argument.
    server_response_reader get_response_reader() {
        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
    }
2852};
2853
2854//
2855// server_context (public API)
2856//
2857
// Constructs the public wrapper by allocating a fresh server_context_impl and storing it in `impl`.
// NOTE(review): presumably `impl` is an owning smart pointer (the destructor is defaulted) - confirm in the header.
server_context::server_context() : impl(new server_context_impl()) {}
// Defaulted destructor, defined out-of-line in this translation unit.
// NOTE(review): presumably placed here so that server_context_impl is a complete type when `impl` is destroyed - confirm.
server_context::~server_context() = default;
2860
2861bool server_context::load_model(const common_params & params) {
2862    return impl->load_model(params);
2863}
2864
2865void server_context::start_loop() {
2866    auto & params = impl->params_base;
2867    impl->queue_tasks.start_loop(params.sleep_idle_seconds * 1000);
2868}
2869
2870void server_context::terminate() {
2871    impl->queue_tasks.terminate();
2872}
2873
2874llama_context * server_context::get_llama_context() const {
2875    return impl->ctx;
2876}
2877
// Public-API forwarder: returns the response reader produced by the implementation object.
server_response_reader server_context::get_response_reader() {
    return impl->get_response_reader();
}
2881
2882server_context_meta server_context::get_meta() const {
2883    auto bos_id = llama_vocab_bos(impl->vocab);
2884    auto eos_id = llama_vocab_eos(impl->vocab);
2885    auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
2886    auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : "";
2887
2888    return server_context_meta {
2889        /* build_info             */ build_info,
2890        /* model_name             */ impl->model_name,
2891        /* model_path             */ impl->params_base.model.path,
2892        /* has_mtmd               */ impl->mctx != nullptr,
2893        /* has_inp_image          */ impl->chat_params.allow_image,
2894        /* has_inp_audio          */ impl->chat_params.allow_audio,
2895        /* json_webui_settings    */ impl->json_webui_settings,
2896        /* slot_n_ctx             */ impl->get_slot_n_ctx(),
2897        /* pooling_type           */ llama_pooling_type(impl->ctx),
2898
2899        /* chat_params            */ impl->chat_params,
2900        /* chat_template_caps     */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
2901
2902        /* bos_token_str          */ bos_token_str,
2903        /* eos_token_str          */ eos_token_str,
2904        /* fim_pre_token          */ llama_vocab_fim_pre(impl->vocab),
2905        /* fim_sub_token          */ llama_vocab_fim_suf(impl->vocab),
2906        /* fim_mid_token          */ llama_vocab_fim_mid(impl->vocab),
2907
2908        /* model_vocab_type       */ llama_vocab_type(impl->vocab),
2909        /* model_vocab_n_tokens   */ llama_vocab_n_tokens(impl->vocab),
2910        /* model_n_ctx_train      */ llama_model_n_ctx_train(impl->model),
2911        /* model_n_embd_inp       */ llama_model_n_embd(impl->model),
2912        /* model_n_params         */ llama_model_n_params(impl->model),
2913        /* model_size             */ llama_model_size(impl->model),
2914    };
2915}
2916
2917
2918
2919// generator-like API for HTTP response generation
2920// may have bypass_sleep = true if the task does not use ctx_server
2921struct server_res_generator : server_http_res {
2922    server_response_reader rd;
2923    server_res_generator(server_queue & queue_tasks, server_response & queue_results, int sleep_idle_seconds, bool bypass_sleep = false)
2924            : rd(queue_tasks, queue_results, HTTP_POLLING_SECONDS) {
2925        // fast path in case sleeping is disabled
2926        bypass_sleep |= sleep_idle_seconds < 0;
2927        if (!bypass_sleep) {
2928            queue_tasks.wait_until_no_sleep();
2929        }
2930    }
2931    void ok(const json & response_data) {
2932        status = 200;
2933        data = safe_json_to_str(response_data);
2934    }
2935    void error(const json & error_data) {
2936        status = json_value(error_data, "code", 500);
2937        data = safe_json_to_str({{ "error", error_data }});
2938    }
2939};
2940
2941
2942
2943//
2944// server_routes
2945//
2946
2947std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
2948            const server_http_req & req,
2949            server_task_type type,
2950            const json & data,
2951            const std::vector<raw_buffer> & files,
2952            task_response_type res_type) {
2953    GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
2954
2955    auto res = create_response();
2956    auto completion_id = gen_chatcmplid();
2957    auto & rd = res->rd;
2958
2959    try {
2960        std::vector<server_task> tasks;
2961
2962        const auto & prompt = data.at("prompt");
2963        // TODO: this log can become very long, put it behind a flag or think about a more compact format
2964        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
2965
2966        // process prompt
2967        std::vector<server_tokens> inputs;
2968
2969        if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
2970            // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
2971            inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
2972        } else {
2973            // Everything else, including multimodal completions.
2974            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
2975        }
2976
2977        // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
2978
2979        for (size_t i = 0; i < inputs.size(); i++) {
2980            server_task task = server_task(type);
2981
2982            task.id = rd.get_new_id();
2983
2984            task.tokens = std::move(inputs[i]);
2985            task.params = server_task::params_from_json_cmpl(
2986                    ctx_server.vocab,
2987                    params,
2988                    meta->slot_n_ctx,
2989                    data);
2990            task.id_slot = json_value(data, "id_slot", -1);
2991
2992            // OAI-compat
2993            task.params.res_type          = res_type;
2994            task.params.oaicompat_cmpl_id = completion_id;
2995            task.params.oaicompat_model   = meta->model_name;
2996
2997            // prepare child tasks
2998            if (task.params.n_cmpl > 1) {
2999                int n_children = task.params.n_cmpl - 1;
3000                for (int j = 0; j < n_children; j++) {
3001                    task.add_child(task.id, rd.get_new_id());
3002                }
3003            }
3004
3005            tasks.push_back(std::move(task));
3006        }
3007
3008        rd.post_tasks(std::move(tasks));
3009    } catch (const std::exception & e) {
3010        res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
3011        return res;
3012    }
3013
3014    bool stream = json_value(data, "stream", false);
3015
3016    if (!stream) {
3017        // non-stream, wait for the results
3018        auto all_results = rd.wait_for_all(req.should_stop);
3019        if (all_results.is_terminated) {
3020            return res; // connection is closed
3021        } else if (all_results.error) {
3022            res->error(all_results.error->to_json());
3023            return res;
3024        } else {
3025            json arr = json::array();
3026            for (auto & res : all_results.results) {
3027                GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
3028                arr.push_back(res->to_json());
3029            }
3030            GGML_ASSERT(!arr.empty() && "empty results");
3031            if (arr.size() == 1) {
3032                // if single request, return single object instead of array
3033                res->ok(arr[0]);
3034            } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
3035                // if multiple results in OAI format, we need to re-format them
3036                json & choices = arr[0]["choices"];
3037                for (size_t i = 1; i < arr.size(); i++) {
3038                    choices.push_back(std::move(arr[i]["choices"][0]));
3039                }
3040                res->ok(arr[0]);
3041            } else {
3042                // multi-results, non-OAI compat
3043                res->ok(arr);
3044            }
3045        }
3046    } else {
3047        // in streaming mode, the first error must be treated as non-stream response
3048        // this is to match the OAI API behavior
3049        // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
3050        auto first_result = rd.next(req.should_stop);
3051        if (first_result == nullptr) {
3052            GGML_ASSERT(req.should_stop());
3053            return res; // connection is closed
3054        }
3055
3056        if (first_result->is_error()) {
3057            res->error(first_result->to_json());
3058            return res;
3059        }
3060
3061        GGML_ASSERT(
3062            dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr ||
3063            dynamic_cast<server_task_result_cmpl_final*>  (first_result.get()) != nullptr
3064        );
3065
3066        // next responses are streamed
3067        // to be sent immediately
3068        json first_result_json = first_result->to_json();
3069        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
3070            res->data = format_anthropic_sse(first_result_json);
3071        } else if (res_type == TASK_RESPONSE_TYPE_OAI_RESP) {
3072            res->data = format_oai_resp_sse(first_result_json);
3073        } else {
3074            res->data = format_oai_sse(first_result_json);
3075        }
3076        res->status = 200;
3077        res->content_type = "text/event-stream";
3078        res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool {
3079            static auto format_error = [](task_response_type res_type, const json & res_json) {
3080                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
3081                    return format_anthropic_sse({
3082                        {"event", "error"},
3083                        {"data", res_json},
3084                    });
3085                } else {
3086                    return format_oai_sse(json {{ "error", res_json }});
3087                }
3088            };
3089
3090            try {
3091                if (req.should_stop()) {
3092                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
3093                    return false; // should_stop condition met
3094                }
3095
3096                if (!res_this->data.empty()) {
3097                    // flush the first chunk
3098                    output = std::move(res_this->data);
3099                    res_this->data.clear();
3100                    return true;
3101                }
3102
3103                server_response_reader & rd = res_this->rd;
3104
3105                // check if there is more data
3106                if (!rd.has_next()) {
3107                    switch (res_type) {
3108                        case TASK_RESPONSE_TYPE_NONE:
3109                        case TASK_RESPONSE_TYPE_OAI_RESP:
3110                        case TASK_RESPONSE_TYPE_ANTHROPIC:
3111                            output = "";
3112                            break;
3113
3114                        default:
3115                            output = "data: [DONE]\n\n";
3116                            break;
3117                    }
3118                    SRV_DBG("%s", "all results received, terminating stream\n");
3119                    return false; // no more data, terminate
3120                }
3121
3122                // receive subsequent results
3123                auto result = rd.next(req.should_stop);
3124                if (result == nullptr) {
3125                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
3126                    GGML_ASSERT(req.should_stop());
3127                    return false; // should_stop condition met
3128                }
3129
3130                // send the results
3131                if (result->is_error()) {
3132                    json res_json = result->to_json();
3133                    output = format_error(res_type, res_json);
3134                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
3135                    return false; // terminate on error
3136                } else {
3137                    GGML_ASSERT(
3138                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
3139                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
3140                    );
3141                    json res_json = result->to_json();
3142                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
3143                        output = format_anthropic_sse(res_json);
3144                    } else if (res_type == TASK_RESPONSE_TYPE_OAI_RESP) {
3145                        output = format_oai_resp_sse(res_json);
3146                    } else {
3147                        output = format_oai_sse(res_json);
3148                    }
3149                }
3150
3151                // has next data, continue
3152                return true;
3153
3154            } catch (const std::exception & e) {
3155                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
3156                output = format_error(res_type, error_json);
3157
3158                // terminate on exception
3159                return false;
3160            }
3161        };
3162    }
3163
3164    return res;
3165}
3166
3167std::unique_ptr<server_res_generator> server_routes::create_response(bool bypass_sleep) {
3168    return std::make_unique<server_res_generator>(queue_tasks, queue_results, params.sleep_idle_seconds, bypass_sleep);
3169}
3170
// Binds the router to the given parameters and server context, then registers all route handlers.
//
// Note on naming: inside the member-initializer list the constructor parameters `params` and
// `ctx_server` shadow the members of the same name, so `ctx_server(*ctx_server.impl)` initializes
// the member from the implementation object of the `server_context` parameter.
server_routes::server_routes(const common_params & params, server_context & ctx_server)
        : params(params),
          ctx_server(*ctx_server.impl),
          queue_tasks(ctx_server.impl->queue_tasks),
          queue_results(ctx_server.impl->queue_results) {
    // handlers are installed last, after all members above are initialized
    init_routes();
}
3178
3179void server_routes::init_routes() {
3180    // IMPORTANT: all lambda functions must start with create_response()
3181    // this is to ensure that the server_res_generator can handle sleeping case correctly
3182
3183    this->get_health = [this](const server_http_req &) {
3184        // error and loading states are handled by middleware
3185        auto res = create_response(true);
3186
3187        // this endpoint can be accessed during sleeping
3188        // the next LOC is to avoid someone accidentally use ctx_server
3189        bool ctx_server; // do NOT delete this line
3190        GGML_UNUSED(ctx_server);
3191
3192        res->ok({{"status", "ok"}});
3193        return res;
3194    };
3195
3196    this->get_metrics = [this](const server_http_req & req) {
3197        auto res = create_response();
3198        if (!params.endpoint_metrics) {
3199            res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
3200            return res;
3201        }
3202
3203        // request slots data using task queue
3204        {
3205            server_task task(SERVER_TASK_TYPE_METRICS);
3206            task.id = res->rd.get_new_id();
3207            res->rd.post_task(std::move(task), true); // high-priority task
3208        }
3209
3210        // get the result
3211        auto result = res->rd.next(req.should_stop);
3212        if (!result) {
3213            // connection was closed
3214            GGML_ASSERT(req.should_stop());
3215            return res;
3216        }
3217
3218        if (result->is_error()) {
3219            res->error(result->to_json());
3220            return res;
3221        }
3222
3223        // TODO: get rid of this dynamic_cast
3224        auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
3225        GGML_ASSERT(res_task != nullptr);
3226
3227        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
3228        json all_metrics_def = json {
3229            {"counter", {{
3230                    {"name",  "prompt_tokens_total"},
3231                    {"help",  "Number of prompt tokens processed."},
3232                    {"value",  (uint64_t) res_task->n_prompt_tokens_processed_total}
3233            }, {
3234                    {"name",  "prompt_seconds_total"},
3235                    {"help",  "Prompt process time"},
3236                    {"value",  (uint64_t) res_task->t_prompt_processing_total / 1.e3}
3237            }, {
3238                    {"name",  "tokens_predicted_total"},
3239                    {"help",  "Number of generation tokens processed."},
3240                    {"value",  (uint64_t) res_task->n_tokens_predicted_total}
3241            }, {
3242                    {"name",  "tokens_predicted_seconds_total"},
3243                    {"help",  "Predict process time"},
3244                    {"value",  (uint64_t) res_task->t_tokens_generation_total / 1.e3}
3245            }, {
3246                    {"name",  "n_decode_total"},
3247                    {"help",  "Total number of llama_decode() calls"},
3248                    {"value",  res_task->n_decode_total}
3249            }, {
3250                    {"name",  "n_tokens_max"},
3251                    {"help",  "Largest observed n_tokens."},
3252                    {"value",  res_task->n_tokens_max}
3253            }, {
3254                    {"name",  "n_busy_slots_per_decode"},
3255                    {"help",  "Average number of busy slots per llama_decode() call"},
3256                    {"value",  (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
3257            }}},
3258            {"gauge", {{
3259                    {"name",  "prompt_tokens_seconds"},
3260                    {"help",  "Average prompt throughput in tokens/s."},
3261                    {"value",  res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
3262            },{
3263                    {"name",  "predicted_tokens_seconds"},
3264                    {"help",  "Average generation throughput in tokens/s."},
3265                    {"value",  res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
3266            },{
3267                    {"name",  "requests_processing"},
3268                    {"help",  "Number of requests processing."},
3269                    {"value",  (uint64_t) res_task->n_processing_slots}
3270            },{
3271                    {"name",  "requests_deferred"},
3272                    {"help",  "Number of requests deferred."},
3273                    {"value",  (uint64_t) res_task->n_tasks_deferred}
3274            }}}
3275        };
3276
3277        std::stringstream prometheus;
3278
3279        for (const auto & el : all_metrics_def.items()) {
3280            const auto & type        = el.key();
3281            const auto & metrics_def = el.value();
3282
3283            for (const auto & metric_def : metrics_def) {
3284                const std::string name = metric_def.at("name");
3285                const std::string help = metric_def.at("help");
3286
3287                auto value = json_value(metric_def, "value", 0.);
3288                prometheus << "# HELP llamacpp:" << name << " " << help  << "\n"
3289                            << "# TYPE llamacpp:" << name << " " << type  << "\n"
3290                            << "llamacpp:"        << name << " " << value << "\n";
3291            }
3292        }
3293
3294        res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
3295        res->content_type = "text/plain; version=0.0.4";
3296        res->status = 200;
3297        res->data = prometheus.str();
3298        return res;
3299    };
3300
3301    this->get_slots = [this](const server_http_req & req) {
3302        auto res = create_response();
3303        if (!params.endpoint_slots) {
3304            res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
3305            return res;
3306        }
3307
3308        // request slots data using task queue
3309        {
3310            server_task task(SERVER_TASK_TYPE_METRICS);
3311            task.id = res->rd.get_new_id();
3312            res->rd.post_task(std::move(task), true); // high-priority task
3313        }
3314
3315        // get the result
3316        auto result = res->rd.next(req.should_stop);
3317        if (!result) {
3318            // connection was closed
3319            GGML_ASSERT(req.should_stop());
3320            return res;
3321        }
3322
3323        if (result->is_error()) {
3324            res->error(result->to_json());
3325            return res;
3326        }
3327
3328        // TODO: get rid of this dynamic_cast
3329        auto * res_task = dynamic_cast<server_task_result_metrics*>(result.get());
3330        GGML_ASSERT(res_task != nullptr);
3331
3332        // optionally return "fail_on_no_slot" error
3333        if (!req.get_param("fail_on_no_slot").empty()) {
3334            if (res_task->n_idle_slots == 0) {
3335                res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
3336                return res;
3337            }
3338        }
3339
3340        res->ok(res_task->slots_data);
3341        return res;
3342    };
3343
3344    this->post_slots = [this](const server_http_req & req) {
3345        auto res = create_response();
3346        if (params.slot_save_path.empty()) {
3347            res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
3348            return res;
3349        }
3350
3351        std::string id_slot_str = req.get_param("id_slot");
3352
3353        int id_slot;
3354        try {
3355            id_slot = std::stoi(id_slot_str);
3356        } catch (const std::exception &) {
3357            res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
3358            return res;
3359        }
3360
3361        std::string action = req.get_param("action");
3362
3363        if (action == "save") {
3364            return handle_slots_save(req, id_slot);
3365        }
3366        if (action == "restore") {
3367            return handle_slots_restore(req, id_slot);
3368        }
3369        if (action == "erase") {
3370            return handle_slots_erase(req, id_slot);
3371        }
3372
3373        res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
3374        return res;
3375    };
3376
3377    this->get_props = [this](const server_http_req &) {
3378        auto res = create_response(true);
3379
3380        // this endpoint can be accessed during sleeping
3381        // the next LOC is to avoid someone accidentally use ctx_server
3382        bool ctx_server; // do NOT delete this line
3383        GGML_UNUSED(ctx_server);
3384
3385        task_params tparams;
3386        tparams.sampling = params.sampling;
3387        json default_generation_settings_for_props = json {
3388            { "params", tparams.to_json(true) },
3389            { "n_ctx",  meta->slot_n_ctx },
3390        };
3391
3392        std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
3393        std::string tmpl_tools   = common_chat_templates_source(meta->chat_params.tmpls.get(), "tool_use");
3394
3395        json props = {
3396            { "default_generation_settings", default_generation_settings_for_props },
3397            { "total_slots",                 params.n_parallel },
3398            { "model_alias",                 meta->model_name },
3399            { "model_path",                  meta->model_path },
3400            { "modalities",                  json {
3401                {"vision", meta->has_inp_image},
3402                {"audio",  meta->has_inp_audio},
3403            } },
3404            { "endpoint_slots",              params.endpoint_slots },
3405            { "endpoint_props",              params.endpoint_props },
3406            { "endpoint_metrics",            params.endpoint_metrics },
3407            { "webui",                       params.webui },
3408            { "webui_settings",              meta->json_webui_settings },
3409            { "chat_template",               tmpl_default },
3410            { "chat_template_caps",          meta->chat_template_caps },
3411            { "bos_token",                   meta->bos_token_str },
3412            { "eos_token",                   meta->eos_token_str },
3413            { "build_info",                  meta->build_info },
3414            { "is_sleeping",                 queue_tasks.is_sleeping() },
3415        };
3416        if (params.use_jinja) {
3417            if (!tmpl_tools.empty()) {
3418                props["chat_template_tool_use"] = tmpl_tools;
3419            }
3420        }
3421        res->ok(props);
3422        return res;
3423    };
3424
3425    this->post_props = [this](const server_http_req &) {
3426        auto res = create_response();
3427        if (!params.endpoint_props) {
3428            res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
3429            return res;
3430        }
3431        // update any props here
3432
3433        res->ok({{ "success", true }});
3434        return res;
3435    };
3436
3437    this->get_api_show = [this](const server_http_req &) {
3438        auto res = create_response();
3439        std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
3440        json data = {
3441            {
3442                "model_info", {
3443                    { "llama.context_length", meta->slot_n_ctx },
3444                }
3445            },
3446            {"modelfile", ""},
3447            {"parameters", ""},
3448            {"template", tmpl_default},
3449            {"details", {
3450                {"parent_model", ""},
3451                {"format", "gguf"},
3452                {"family", ""},
3453                {"families", {""}},
3454                {"parameter_size", ""},
3455                {"quantization_level", ""}
3456            }},
3457            {"model_info", ""},
3458            {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
3459        };
3460
3461        res->ok(data);
3462        return res;
3463    };
3464
3465    this->post_infill = [this](const server_http_req & req) {
3466        auto res = create_response();
3467        // check model compatibility
3468        std::string err;
3469        if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
3470            err += "prefix token is missing. ";
3471        }
3472        if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
3473            err += "suffix token is missing. ";
3474        }
3475        if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
3476            err += "middle token is missing. ";
3477        }
3478        if (!err.empty()) {
3479            res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
3480            return res;
3481        }
3482
3483        // validate input
3484        json data = json::parse(req.body);
3485        if (data.contains("prompt") && !data.at("prompt").is_string()) {
3486            // prompt is optional
3487            res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
3488        }
3489
3490        if (!data.contains("input_prefix")) {
3491            res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
3492        }
3493
3494        if (!data.contains("input_suffix")) {
3495            res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
3496        }
3497
3498        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
3499            // input_extra is optional
3500            res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
3501            return res;
3502        }
3503
3504        json input_extra = json_value(data, "input_extra", json::array());
3505        for (const auto & chunk : input_extra) {
3506            // { "text": string, "filename": string }
3507            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
3508                res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
3509                return res;
3510            }
3511            // filename is optional
3512            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
3513                res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
3514                return res;
3515            }
3516        }
3517        data["input_extra"] = input_extra; // default to empty array if it's not exist
3518
3519        std::string prompt = json_value(data, "prompt", std::string());
3520        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
3521        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
3522        data["prompt"] = format_prompt_infill(
3523            ctx_server.vocab,
3524            data.at("input_prefix"),
3525            data.at("input_suffix"),
3526            data.at("input_extra"),
3527            params.n_batch,
3528            params.n_predict,
3529            meta->slot_n_ctx,
3530            params.spm_infill,
3531            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
3532        );
3533
3534        std::vector<raw_buffer> files; // dummy
3535        return handle_completions_impl(
3536            req,
3537            SERVER_TASK_TYPE_INFILL,
3538            data,
3539            files,
3540            TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
3541    };
3542
3543    this->post_completions = [this](const server_http_req & req) {
3544        auto res = create_response();
3545        std::vector<raw_buffer> files; // dummy
3546        const json body = json::parse(req.body);
3547        return handle_completions_impl(
3548            req,
3549            SERVER_TASK_TYPE_COMPLETION,
3550            body,
3551            files,
3552            TASK_RESPONSE_TYPE_NONE);
3553    };
3554
3555    this->post_completions_oai = [this](const server_http_req & req) {
3556        auto res = create_response();
3557        std::vector<raw_buffer> files; // dummy
3558        const json body = json::parse(req.body);
3559        return handle_completions_impl(
3560            req,
3561            SERVER_TASK_TYPE_COMPLETION,
3562            body,
3563            files,
3564            TASK_RESPONSE_TYPE_OAI_CMPL);
3565    };
3566
3567    this->post_chat_completions = [this](const server_http_req & req) {
3568        auto res = create_response();
3569        std::vector<raw_buffer> files;
3570        json body = json::parse(req.body);
3571        json body_parsed = oaicompat_chat_params_parse(
3572            body,
3573            meta->chat_params,
3574            files);
3575        return handle_completions_impl(
3576            req,
3577            SERVER_TASK_TYPE_COMPLETION,
3578            body_parsed,
3579            files,
3580            TASK_RESPONSE_TYPE_OAI_CHAT);
3581    };
3582
3583    this->post_responses_oai = [this](const server_http_req & req) {
3584        auto res = create_response();
3585        std::vector<raw_buffer> files;
3586        json body = convert_responses_to_chatcmpl(json::parse(req.body));
3587        SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions");
3588        SRV_DBG("converted request: %s\n", body.dump().c_str());
3589        json body_parsed = oaicompat_chat_params_parse(
3590            body,
3591            meta->chat_params,
3592            files);
3593        return handle_completions_impl(
3594            req,
3595            SERVER_TASK_TYPE_COMPLETION,
3596            body_parsed,
3597            files,
3598            TASK_RESPONSE_TYPE_OAI_RESP);
3599    };
3600
3601    this->post_anthropic_messages = [this](const server_http_req & req) {
3602        auto res = create_response();
3603        std::vector<raw_buffer> files;
3604        json body = convert_anthropic_to_oai(json::parse(req.body));
3605        SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
3606        SRV_DBG("converted request: %s\n", body.dump().c_str());
3607        json body_parsed = oaicompat_chat_params_parse(
3608            body,
3609            meta->chat_params,
3610            files);
3611        return handle_completions_impl(
3612            req,
3613            SERVER_TASK_TYPE_COMPLETION,
3614            body_parsed,
3615            files,
3616            TASK_RESPONSE_TYPE_ANTHROPIC);
3617    };
3618
3619    this->post_anthropic_count_tokens = [this](const server_http_req & req) {
3620        auto res = create_response();
3621        std::vector<raw_buffer> files;
3622        json body = convert_anthropic_to_oai(json::parse(req.body));
3623        SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions");
3624        SRV_DBG("converted request: %s\n", body.dump().c_str());
3625        json body_parsed = oaicompat_chat_params_parse(
3626            body,
3627            meta->chat_params,
3628            files);
3629
3630        json prompt = body_parsed.at("prompt");
3631        llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
3632        res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
3633        return res;
3634    };
3635
3636    // same with handle_chat_completions, but without inference part
3637    this->post_apply_template = [this](const server_http_req & req) {
3638        auto res = create_response();
3639        std::vector<raw_buffer> files; // dummy, unused
3640        json body = json::parse(req.body);
3641        json data = oaicompat_chat_params_parse(
3642            body,
3643            meta->chat_params,
3644            files);
3645        res->ok({{ "prompt", std::move(data.at("prompt")) }});
3646        return res;
3647    };
3648
3649    this->get_models = [this](const server_http_req &) {
3650        auto res = create_response(true);
3651
3652        // this endpoint can be accessed during sleeping
3653        // the next LOC is to avoid someone accidentally use ctx_server
3654        bool ctx_server; // do NOT delete this line
3655        GGML_UNUSED(ctx_server);
3656
3657        json models = {
3658            {"models", {
3659                {
3660                    {"name",  meta->model_name},
3661                    {"model", meta->model_name},
3662                    {"modified_at", ""},
3663                    {"size", ""},
3664                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
3665                    {"type", "model"},
3666                    {"description", ""},
3667                    {"tags", {""}},
3668                    {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
3669                    {"parameters", ""},
3670                    {"details", {
3671                        {"parent_model", ""},
3672                        {"format", "gguf"},
3673                        {"family", ""},
3674                        {"families", {""}},
3675                        {"parameter_size", ""},
3676                        {"quantization_level", ""}
3677                    }}
3678                }
3679            }},
3680            {"object", "list"},
3681            {"data", {
3682                {
3683                    {"id",       meta->model_name},
3684                    {"object",   "model"},
3685                    {"created",  std::time(0)},
3686                    {"owned_by", "llamacpp"},
3687                    {"meta",     {
3688                        {"vocab_type",  meta->model_vocab_type},
3689                        {"n_vocab",     meta->model_vocab_n_tokens},
3690                        {"n_ctx_train", meta->model_n_ctx_train},
3691                        {"n_embd",      meta->model_n_embd_inp},
3692                        {"n_params",    meta->model_n_params},
3693                        {"size",        meta->model_size},
3694                    }},
3695                },
3696            }}
3697        };
3698
3699        res->ok(models);
3700        return res;
3701    };
3702
3703    this->post_tokenize = [this](const server_http_req & req) {
3704        auto res = create_response();
3705        const json body = json::parse(req.body);
3706        json tokens_response = json::array();
3707        if (body.count("content") != 0) {
3708            const bool add_special = json_value(body, "add_special", false);
3709            const bool parse_special = json_value(body, "parse_special", true);
3710            const bool with_pieces = json_value(body, "with_pieces", false);
3711
3712            llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);
3713
3714            if (with_pieces) {
3715                for (const auto& token : tokens) {
3716                    std::string piece = common_token_to_piece(ctx_server.vocab, token);
3717                    json piece_json;
3718
3719                    // Check if the piece is valid UTF-8
3720                    if (is_valid_utf8(piece)) {
3721                        piece_json = piece;
3722                    } else {
3723                        // If not valid UTF-8, store as array of byte values
3724                        piece_json = json::array();
3725                        for (unsigned char c : piece) {
3726                            piece_json.push_back(static_cast<int>(c));
3727                        }
3728                    }
3729
3730                    tokens_response.push_back({
3731                        {"id", token},
3732                        {"piece", piece_json}
3733                    });
3734                }
3735            } else {
3736                tokens_response = tokens;
3737            }
3738        }
3739
3740        res->ok(json{{"tokens", std::move(tokens_response)}});
3741        return res;
3742    };
3743
3744    this->post_detokenize = [this](const server_http_req & req) {
3745        auto res = create_response();
3746        const json body = json::parse(req.body);
3747
3748        std::string content;
3749        if (body.count("tokens") != 0) {
3750            const llama_tokens tokens = body.at("tokens");
3751            content = tokens_to_str(ctx_server.vocab, tokens);
3752        }
3753
3754        res->ok(json{{"content", std::move(content)}});
3755        return res;
3756    };
3757
3758    this->post_embeddings = [this](const server_http_req & req) {
3759        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
3760    };
3761
3762    this->post_embeddings_oai = [this](const server_http_req & req) {
3763        return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
3764    };
3765
3766    this->post_rerank = [this](const server_http_req & req) {
3767        auto res = create_response();
3768        if (!params.embedding || params.pooling_type != LLAMA_POOLING_TYPE_RANK) {
3769            res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3770            return res;
3771        }
3772
3773        const json body = json::parse(req.body);
3774
3775        // if true, use TEI API format, otherwise use Jina API format
3776        // Jina: https://jina.ai/reranker/
3777        // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
3778        bool is_tei_format = body.contains("texts");
3779
3780        json query;
3781        if (body.count("query") == 1) {
3782            query = body.at("query");
3783            if (!query.is_string()) {
3784                res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
3785                return res;
3786            }
3787        } else {
3788            res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
3789            return res;
3790        }
3791
3792        std::vector<std::string> documents = json_value(body, "documents",
3793                                             json_value(body, "texts", std::vector<std::string>()));
3794        if (documents.empty()) {
3795            res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
3796            return res;
3797        }
3798
3799        int top_n = json_value(body, "top_n", (int)documents.size());
3800
3801        // create and queue the task
3802        json responses = json::array();
3803        auto & rd = res->rd;
3804        {
3805            std::vector<server_task> tasks;
3806            tasks.reserve(documents.size());
3807            for (size_t i = 0; i < documents.size(); i++) {
3808                auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
3809                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
3810                task.id     = rd.get_new_id();
3811                task.tokens = std::move(tmp);
3812                tasks.push_back(std::move(task));
3813            }
3814            rd.post_tasks(std::move(tasks));
3815        }
3816
3817        // wait for the results
3818        auto all_results = rd.wait_for_all(req.should_stop);
3819
3820        // collect results
3821        if (all_results.is_terminated) {
3822            return res; // connection is closed
3823        } else if (all_results.error) {
3824            res->error(all_results.error->to_json());
3825            return res;
3826        } else {
3827            for (auto & res : all_results.results) {
3828                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
3829                responses.push_back(res->to_json());
3830            }
3831        }
3832
3833        // write JSON response
3834        json root = format_response_rerank(
3835            body,
3836            meta->model_name,
3837            responses,
3838            is_tei_format,
3839            documents,
3840            top_n);
3841
3842        res->ok(root);
3843        return res;
3844    };
3845
3846    this->get_lora_adapters = [this](const server_http_req & req) {
3847        auto res = create_response();
3848
3849        auto & rd = res->rd;
3850        {
3851            server_task task(SERVER_TASK_TYPE_GET_LORA);
3852            task.id = rd.get_new_id();
3853            rd.post_task(std::move(task));
3854        }
3855
3856        // get the result
3857        auto result = rd.next(req.should_stop);
3858        if (!result) {
3859            // connection was closed
3860            GGML_ASSERT(req.should_stop());
3861            return res;
3862        }
3863
3864        if (result->is_error()) {
3865            res->error(result->to_json());
3866            return res;
3867        }
3868
3869        GGML_ASSERT(dynamic_cast<server_task_result_get_lora*>(result.get()) != nullptr);
3870        res->ok(result->to_json());
3871        return res;
3872    };
3873
3874    this->post_lora_adapters = [this](const server_http_req & req) {
3875        auto res = create_response();
3876        const json body = json::parse(req.body);
3877        if (!body.is_array()) {
3878            res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
3879            return res;
3880        }
3881
3882        auto & rd = res->rd;
3883        {
3884            server_task task(SERVER_TASK_TYPE_SET_LORA);
3885            task.id = rd.get_new_id();
3886            task.set_lora = parse_lora_request(body);
3887            rd.post_task(std::move(task));
3888        }
3889
3890        // get the result
3891        auto result = rd.next(req.should_stop);
3892        if (!result) {
3893            // connection was closed
3894            GGML_ASSERT(req.should_stop());
3895            return res;
3896        }
3897
3898        if (result->is_error()) {
3899            res->error(result->to_json());
3900            return res;
3901        }
3902
3903        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
3904        res->ok(result->to_json());
3905        return res;
3906    };
3907}
3908
3909std::unique_ptr<server_res_generator> server_routes::handle_slots_save(const server_http_req & req, int id_slot) {
3910    auto res = create_response();
3911    const json request_data = json::parse(req.body);
3912    std::string filename = request_data.at("filename");
3913    if (!fs_validate_filename(filename)) {
3914        res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
3915        return res;
3916    }
3917    std::string filepath = params.slot_save_path + filename;
3918
3919    auto & rd = res->rd;
3920    {
3921        server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
3922        task.id = rd.get_new_id();
3923        task.slot_action.id_slot  = id_slot;
3924        task.slot_action.filename = filename;
3925        task.slot_action.filepath = filepath;
3926        rd.post_task(std::move(task));
3927    }
3928
3929    auto result = rd.next(req.should_stop);
3930    if (!result) {
3931        // connection was closed
3932        GGML_ASSERT(req.should_stop());
3933        return res;
3934    }
3935
3936    if (result->is_error()) {
3937        res->error(result->to_json());
3938        return res;
3939    }
3940
3941    res->ok(result->to_json());
3942    return res;
3943}
3944
3945std::unique_ptr<server_res_generator> server_routes::handle_slots_restore(const server_http_req & req, int id_slot) {
3946    auto res = create_response();
3947    const json request_data = json::parse(req.body);
3948    std::string filename = request_data.at("filename");
3949    if (!fs_validate_filename(filename)) {
3950        res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
3951        return res;
3952    }
3953    std::string filepath = params.slot_save_path + filename;
3954
3955    auto & rd = res->rd;
3956    {
3957        server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
3958        task.id = rd.get_new_id();
3959        task.slot_action.id_slot  = id_slot;
3960        task.slot_action.filename = filename;
3961        task.slot_action.filepath = filepath;
3962        rd.post_task(std::move(task));
3963    }
3964
3965    auto result = rd.next(req.should_stop);
3966    if (!result) {
3967        // connection was closed
3968        GGML_ASSERT(req.should_stop());
3969        return res;
3970    }
3971
3972    if (result->is_error()) {
3973        res->error(result->to_json());
3974        return res;
3975    }
3976
3977    GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
3978    res->ok(result->to_json());
3979    return res;
3980}
3981
3982std::unique_ptr<server_res_generator> server_routes::handle_slots_erase(const server_http_req & req, int id_slot) {
3983    auto res = create_response();
3984    auto & rd = res->rd;
3985    {
3986        server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
3987        task.id = rd.get_new_id();
3988        task.slot_action.id_slot = id_slot;
3989        rd.post_task(std::move(task));
3990    }
3991
3992    auto result = rd.next(req.should_stop);
3993    if (!result) {
3994        // connection was closed
3995        GGML_ASSERT(req.should_stop());
3996        return res;
3997    }
3998
3999    if (result->is_error()) {
4000        res->error(result->to_json());
4001        return res;
4002    }
4003
4004    GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
4005    res->ok(result->to_json());
4006    return res;
4007}
4008
4009std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
4010    auto res = create_response();
4011    if (!params.embedding) {
4012        res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
4013        return res;
4014    }
4015
4016    if (res_type != TASK_RESPONSE_TYPE_NONE && meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
4017        res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
4018        return res;
4019    }
4020
4021    const json body = json::parse(req.body);
4022
4023    // for the shape of input/content, see tokenize_input_prompts()
4024    json prompt;
4025    if (body.count("input") != 0) {
4026        prompt = body.at("input");
4027    } else if (body.contains("content")) {
4028        res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
4029        prompt = body.at("content");
4030    } else {
4031        res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
4032        return res;
4033    }
4034
4035    bool use_base64 = false;
4036    if (body.count("encoding_format") != 0) {
4037        const std::string & format = body.at("encoding_format");
4038        if (format == "base64") {
4039            use_base64 = true;
4040        } else if (format != "float") {
4041            res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
4042            return res;
4043        }
4044    }
4045
4046    auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
4047    for (const auto & tokens : tokenized_prompts) {
4048        // this check is necessary for models that do not add BOS token to the input
4049        if (tokens.empty()) {
4050            res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
4051            return res;
4052        }
4053    }
4054
4055    int embd_normalize = 2; // default to Euclidean/L2 norm
4056    if (body.count("embd_normalize") != 0) {
4057        embd_normalize = body.at("embd_normalize");
4058        if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
4059            SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type);
4060        }
4061    }
4062
4063    // create and queue the task
4064    json responses = json::array();
4065    auto & rd = res->rd;
4066    {
4067        std::vector<server_task> tasks;
4068        for (size_t i = 0; i < tokenized_prompts.size(); i++) {
4069            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
4070
4071            task.id     = rd.get_new_id();
4072            task.tokens = std::move(tokenized_prompts[i]);
4073
4074            // OAI-compat
4075            task.params.res_type = res_type;
4076            task.params.embd_normalize = embd_normalize;
4077
4078            tasks.push_back(std::move(task));
4079        }
4080        rd.post_tasks(std::move(tasks));
4081    }
4082
4083    // wait for the results
4084    auto all_results = rd.wait_for_all(req.should_stop);
4085
4086    // collect results
4087    if (all_results.is_terminated) {
4088        return res; // connection is closed
4089    } else if (all_results.error) {
4090        res->error(all_results.error->to_json());
4091        return res;
4092    } else {
4093        for (auto & res : all_results.results) {
4094            GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
4095            responses.push_back(res->to_json());
4096        }
4097    }
4098
4099    // write JSON response
4100    json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
4101        ? format_embeddings_response_oaicompat(body, meta->model_name, responses, use_base64)
4102        : json(responses);
4103    res->ok(root);
4104    return res;
4105}