1#pragma once
  2
  3#include "common.h"
  4#include "llama.h"
  5
  6#include <string>
  7#include <unordered_set>
  8#include <list>
  9#include <map>
 10
 11// TODO: prevent including the whole server-common.h as we only use server_tokens
 12#include "server-common.h"
 13
 14using json = nlohmann::ordered_json;
 15
 16enum server_task_type {
 17    SERVER_TASK_TYPE_COMPLETION,
 18    SERVER_TASK_TYPE_EMBEDDING,
 19    SERVER_TASK_TYPE_RERANK,
 20    SERVER_TASK_TYPE_INFILL,
 21    SERVER_TASK_TYPE_CANCEL,
 22    SERVER_TASK_TYPE_NEXT_RESPONSE,
 23    SERVER_TASK_TYPE_METRICS,
 24    SERVER_TASK_TYPE_SLOT_SAVE,
 25    SERVER_TASK_TYPE_SLOT_RESTORE,
 26    SERVER_TASK_TYPE_SLOT_ERASE,
 27    SERVER_TASK_TYPE_GET_LORA,
 28    SERVER_TASK_TYPE_SET_LORA,
 29};
 30
 31// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
 32enum task_response_type {
 33    TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
 34    TASK_RESPONSE_TYPE_OAI_CHAT,
 35    TASK_RESPONSE_TYPE_OAI_CMPL,
 36    TASK_RESPONSE_TYPE_OAI_RESP,
 37    TASK_RESPONSE_TYPE_OAI_EMBD,
 38    TASK_RESPONSE_TYPE_ANTHROPIC,
 39};
 40
 41enum stop_type {
 42    STOP_TYPE_NONE,
 43    STOP_TYPE_EOS,
 44    STOP_TYPE_WORD,
 45    STOP_TYPE_LIMIT,
 46};
 47
 48struct task_params {
 49    bool stream          = true;
 50    bool include_usage   = false;
 51    bool cache_prompt    = true; // remember the prompt to avoid reprocessing all prompt
 52    bool return_tokens   = false;
 53    bool return_progress = false;
 54
 55    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
 56    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
 57    int32_t n_predict = -1; // new tokens to predict
 58    int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters
 59    int32_t n_cmpl    =  1; // number of completions to generate from this prompt
 60
 61    int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
 62
 63    int64_t t_max_prompt_ms  = -1; // TODO: implement
 64    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 65
 66    std::map<int, float> lora; // mapping adapter ID -> scale
 67
 68    std::vector<std::string> antiprompt;
 69    std::vector<std::string> response_fields;
 70
 71    bool timings_per_token   = false;
 72    bool post_sampling_probs = false;
 73
 74    struct common_params_sampling sampling;
 75    struct common_params_speculative speculative;
 76
 77    // response formatting
 78    bool               verbose  = false;
 79    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
 80    std::string        oaicompat_model;
 81    std::string        oaicompat_cmpl_id;
 82
 83    // per-request parameters for chat parsing
 84    common_chat_parser_params chat_parser_params;
 85
 86    // Embeddings
 87    int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
 88
 89    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
 90    json to_json(bool only_metrics = false) const;
 91};
 92
 93// struct for tracking the state of a task (e.g., for streaming)
 94struct task_result_state {
 95    // tracking diffs for partial tool calls
 96    std::vector<common_chat_msg_diff> diffs;
 97    common_chat_parser_params chat_parser_params;
 98    common_chat_msg chat_msg;
 99    std::string generated_text; // append new chunks of generated text here
100    std::vector<std::string> generated_tool_call_ids;
101
102    // for OpenAI Responses and Anthropic streaming API:
103    // track output item / content block state across chunks
104    bool thinking_block_started = false;
105    bool text_block_started = false;
106
107    // for OpenAI Responses streaming API
108    const std::string oai_resp_id;
109    const std::string oai_resp_reasoning_id;
110    const std::string oai_resp_message_id;
111    std::string oai_resp_fc_id; // function call ID for current args delta
112
113    task_result_state(const common_chat_parser_params & chat_parser_params)
114        : chat_parser_params(chat_parser_params)
115        , oai_resp_id("resp_" + random_string())
116        , oai_resp_reasoning_id("rs_" + random_string())
117        , oai_resp_message_id("msg_" + random_string()) {}
118
119    // parse partial tool calls and update the internal state
120    common_chat_msg update_chat_msg(
121        const std::string & text_added,
122        bool is_partial,
123        std::vector<common_chat_msg_diff> & diffs);
124};
125
126struct server_task {
127    int id = -1; // to be filled by server_queue
128
129    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
130    size_t index = 0; // used when there are multiple prompts (batch request)
131
132    // used by SERVER_TASK_TYPE_CANCEL
133    int id_target = -1;
134    int id_slot   = -1;
135
136    // used by parallel sampling (multiple completions from same prompt)
137    int id_parent  = -1;
138    // temporary store of child tasks for scheduling
139    // note: accessing to elements is invalid after the task is moved to server_slot
140    std::vector<server_task> child_tasks;
141
142    // used by SERVER_TASK_TYPE_INFERENCE
143    task_params   params;
144    server_tokens tokens;
145
146    // only used by CLI, this allow tokenizing CLI inputs on server side
147    // we need this because mtmd_context and vocab are not accessible outside of server_context
148    bool                    cli = false;
149    std::string             cli_prompt;
150    std::vector<raw_buffer> cli_files;
151
152    server_task_type type;
153
154    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
155    struct slot_action {
156        int id_slot;
157        std::string filename;
158        std::string filepath;
159    };
160    slot_action slot_action;
161
162    // used by SERVER_TASK_TYPE_METRICS
163    bool metrics_reset_bucket = false;
164
165    // used by SERVER_TASK_TYPE_SET_LORA
166    std::map<int, float> set_lora; // mapping adapter ID -> scale
167
168    server_task() = default;
169
170    server_task(server_task_type type) : type(type) {}
171
172    int32_t n_tokens() const {
173        return tokens.size();
174    }
175
176    bool need_embd() const {
177        switch (type) {
178            case SERVER_TASK_TYPE_EMBEDDING:
179            case SERVER_TASK_TYPE_RERANK:
180                return true;
181            default:
182                return false;
183        }
184    }
185
186    bool need_logits() const {
187        switch (type) {
188            case SERVER_TASK_TYPE_COMPLETION:
189            case SERVER_TASK_TYPE_INFILL:
190                return true;
191            default:
192                return false;
193        }
194    }
195
196    bool need_sampling() const {
197        switch (type) {
198            case SERVER_TASK_TYPE_COMPLETION:
199            case SERVER_TASK_TYPE_INFILL:
200                return true;
201            default:
202                return false;
203        }
204    }
205
206    static task_params params_from_json_cmpl(
207        const llama_vocab * vocab,
208        const common_params & params_base,
209        const int n_ctx_slot,
210        const json & data);
211
212    // utility function
213    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
214        std::unordered_set<int> ids(tasks.size());
215        for (size_t i = 0; i < tasks.size(); i++) {
216            ids.insert(tasks[i].id);
217            for (auto & child : tasks[i].child_tasks) {
218                ids.insert(child.id);
219            }
220        }
221        return ids;
222    }
223
224    void add_child(int id_parent, int id_child) {
225        server_task copy;
226
227        copy.id        = id_child;
228        copy.id_parent = id_parent;
229        copy.params    = params;
230        copy.type      = type;
231        copy.tokens    = tokens.clone();
232        copy.id_slot   = -1; // child tasks cannot specify slot
233
234        // use different sampling seed for each child
235        // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
236        if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
237            copy.params.sampling.seed += (uint32_t)child_tasks.size() + 1;
238        }
239
240        child_tasks.push_back(std::move(copy));
241    }
242
243    // the task will be moved into queue, then onto slots
244    // however, the state must be kept by caller (e.g., HTTP thread)
245    task_result_state create_state() const {
246        return task_result_state(params.chat_parser_params);
247    }
248
249    bool is_parent() const {
250        return child_tasks.size() > 0;
251    }
252
253    bool is_child() const {
254        return id_parent != -1;
255    }
256};
257
258struct result_timings {
259    int32_t cache_n = -1;
260
261    int32_t prompt_n = -1;
262    double prompt_ms;
263    double prompt_per_token_ms;
264    double prompt_per_second;
265
266    int32_t predicted_n = -1;
267    double predicted_ms;
268    double predicted_per_token_ms;
269    double predicted_per_second;
270
271    // Optional speculative metrics - only included when > 0
272    int32_t draft_n = 0;
273    int32_t draft_n_accepted = 0;
274
275    json to_json() const;
276};
277
278struct result_prompt_progress {
279    int32_t total = 0;
280    int32_t cache = 0;
281    int32_t processed = 0;
282    int64_t time_ms = 0;
283
284    json to_json() const;
285};
286
287struct server_task_result {
288    int id           = -1;
289    int id_slot      = -1;
290
291    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
292    size_t index = 0; // to be used for batched tasks
293
294    virtual bool is_error() {
295        // only used by server_task_result_error
296        return false;
297    }
298    virtual bool is_stop() {
299        // only used by server_task_result_cmpl_*
300        return true;
301    }
302    virtual void update(task_result_state &) {
303        // only used by server_task_result_cmpl_*
304    }
305    virtual json to_json() = 0;
306    virtual ~server_task_result() = default;
307};
308
309// using shared_ptr for polymorphism of server_task_result
310using server_task_result_ptr = std::unique_ptr<server_task_result>;
311
312struct completion_token_output {
313    llama_token tok;
314    float prob;
315    std::string text_to_send;
316    struct prob_info {
317        llama_token tok;
318        std::string txt;
319        float prob;
320    };
321    std::vector<prob_info> probs;
322
323    json to_json(bool post_sampling_probs) const;
324
325    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
326
327    static float logarithm(float x);
328
329    static std::vector<unsigned char> str_to_bytes(const std::string & str);
330
331};
332
333struct server_task_result_cmpl_final : server_task_result {
334    std::string content;
335    llama_tokens tokens;
336
337    bool stream;
338    bool include_usage;
339    result_timings timings;
340    std::string prompt;
341
342    bool truncated;
343    int32_t n_decoded;
344    int32_t n_prompt_tokens;
345    int32_t n_tokens_cached;
346    bool has_new_line;
347    std::string stopping_word;
348    stop_type stop = STOP_TYPE_NONE;
349
350    bool post_sampling_probs;
351    std::vector<completion_token_output> probs_output;
352    std::vector<std::string>  response_fields;
353
354    task_params generation_params;
355
356    // response formatting
357    bool               verbose  = false;
358    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
359    std::string        oaicompat_model;
360    std::string        oaicompat_cmpl_id;
361    common_chat_msg    oaicompat_msg; // to be populated by update()
362
363    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
364    bool is_updated = false;
365
366    // for OpenAI Responses API
367    std::string oai_resp_id;
368    std::string oai_resp_reasoning_id;
369    std::string oai_resp_message_id;
370
371    virtual bool is_stop() override {
372        return true; // in stream mode, final responses are considered stop
373    }
374
375    virtual json to_json() override;
376
377    virtual void update(task_result_state & state) override {
378        is_updated = true;
379        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
380
381        oai_resp_id = state.oai_resp_id;
382        oai_resp_reasoning_id = state.oai_resp_reasoning_id;
383        oai_resp_message_id = state.oai_resp_message_id;
384    }
385
386    json to_json_non_oaicompat();
387
388    json to_json_oaicompat();
389
390    json to_json_oaicompat_chat();
391
392    json to_json_oaicompat_chat_stream();
393
394    json to_json_oaicompat_resp();
395
396    json to_json_oaicompat_resp_stream();
397
398    json to_json_anthropic();
399
400    json to_json_anthropic_stream();
401};
402
403struct server_task_result_cmpl_partial : server_task_result {
404    std::string  content;
405    llama_tokens tokens;
406
407    int32_t n_decoded;
408    int32_t n_prompt_tokens;
409
410    bool post_sampling_probs;
411    bool is_progress = false;
412    completion_token_output prob_output;
413    result_timings timings;
414    result_prompt_progress progress;
415
416    // response formatting
417    bool               verbose  = false;
418    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
419    std::string        oaicompat_model;
420    std::string        oaicompat_cmpl_id;
421    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
422    bool is_updated = false;
423
424    // Streaming state copied from task_result_state for this chunk
425    bool thinking_block_started = false;
426    bool text_block_started     = false;
427
428    // for OpenAI Responses API
429    std::string oai_resp_id;
430    std::string oai_resp_reasoning_id;
431    std::string oai_resp_message_id;
432    std::string oai_resp_fc_id;
433
434    // for Anthropic API: track if any reasoning content has been generated
435    bool anthropic_has_reasoning = false;
436
437    virtual bool is_stop() override {
438        return false; // in stream mode, partial responses are not considered stop
439    }
440
441    virtual void update(task_result_state & state) override;
442
443    virtual json to_json() override;
444
445    json to_json_non_oaicompat();
446
447    json to_json_oaicompat();
448
449    json to_json_oaicompat_chat();
450
451    json to_json_oaicompat_resp();
452
453    json to_json_anthropic();
454};
455
456struct server_task_result_embd : server_task_result {
457    std::vector<std::vector<float>> embedding;
458
459    int32_t n_tokens;
460
461    // response formatting
462    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
463
464    virtual json to_json() override;
465
466    json to_json_non_oaicompat();
467
468    json to_json_oaicompat();
469};
470
471struct server_task_result_rerank : server_task_result {
472    float score = -1e6;
473
474    int32_t n_tokens;
475
476    virtual json to_json() override;
477};
478
479struct server_task_result_error : server_task_result {
480    error_type err_type = ERROR_TYPE_SERVER;
481    std::string err_msg;
482
483    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
484    int32_t n_prompt_tokens = 0;
485    int32_t n_ctx           = 0;
486
487    virtual bool is_error() override {
488        return true;
489    }
490
491    virtual json to_json() override;
492};
493
494struct server_task_result_metrics : server_task_result {
495    int n_idle_slots;
496    int n_processing_slots;
497    int n_tasks_deferred;
498    int64_t t_start;
499
500    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
501    uint64_t n_prompt_tokens_processed_total = 0;
502    uint64_t t_prompt_processing_total       = 0;
503    uint64_t n_tokens_predicted_total        = 0;
504    uint64_t t_tokens_generation_total       = 0;
505
506    uint64_t n_tokens_max = 0;
507
508    uint64_t n_prompt_tokens_processed = 0;
509    uint64_t t_prompt_processing       = 0;
510
511    uint64_t n_tokens_predicted  = 0;
512    uint64_t t_tokens_generation = 0;
513
514    uint64_t n_decode_total     = 0;
515    uint64_t n_busy_slots_total = 0;
516
517    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
518    // therefore, we use json to temporarily store the slot.to_json() result
519    json slots_data = json::array();
520
521    virtual json to_json() override;
522};
523
524struct server_task_result_slot_save_load : server_task_result {
525    std::string filename;
526    bool is_save; // true = save, false = load
527
528    size_t n_tokens;
529    size_t n_bytes;
530    double t_ms;
531
532    virtual json to_json() override;
533};
534
535struct server_task_result_slot_erase : server_task_result {
536    size_t n_erased;
537
538    virtual json to_json() override;
539};
540
541struct server_task_result_get_lora : server_task_result {
542    struct lora {
543        common_adapter_lora_info info;
544        std::string  alora_invocation_string;
545        llama_tokens alora_invocation_tokens;
546    };
547    std::vector<lora> loras;
548
549    virtual json to_json() override;
550};
551
552struct server_task_result_apply_lora : server_task_result {
553    virtual json to_json() override;
554};
555
556struct server_prompt_checkpoint {
557    llama_pos pos_min;
558    llama_pos pos_max;
559
560    std::vector<uint8_t> data;
561
562    size_t size() const {
563        return data.size();
564    }
565};
566
567struct server_prompt {
568    server_tokens tokens;
569
570    std::vector<uint8_t> data;
571
572    std::list<server_prompt_checkpoint> checkpoints;
573
574    size_t size() const {
575        size_t res = data.size();
576
577        for (const auto & checkpoint : checkpoints) {
578            res += checkpoint.size();
579        }
580
581        return res;
582    }
583
584    int n_tokens() const {
585        return tokens.size();
586    }
587
588    server_prompt clone() const {
589        return server_prompt {
590            tokens.clone(),
591            data,
592            checkpoints
593        };
594    }
595};
596
597struct server_prompt_cache {
598    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
599        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
600        this->limit_tokens = limit_tokens;
601    }
602
603    std::list<server_prompt> states;
604
605    // in bytes, 0 = no limit
606    size_t limit_size = 0;
607
608    // in tokens, 0 = no limit
609    size_t limit_tokens = 0;
610
611    size_t size() const;
612
613    size_t n_tokens() const;
614
615    server_prompt * alloc(const server_prompt & prompt, size_t state_size);
616
617    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
618
619    void update();
620};