#pragma once

#include "common.h"
#include "llama.h"

#include <string>
#include <unordered_set>
#include <list>
#include <map>

// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"

using json = nlohmann::ordered_json;

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
    SERVER_TASK_TYPE_SLOT_SAVE,
    SERVER_TASK_TYPE_SLOT_RESTORE,
    SERVER_TASK_TYPE_SLOT_ERASE,
    SERVER_TASK_TYPE_GET_LORA,
    SERVER_TASK_TYPE_SET_LORA,
};

// TODO: change this to a more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
    TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
    TASK_RESPONSE_TYPE_OAI_CHAT,
    TASK_RESPONSE_TYPE_OAI_CMPL,
    TASK_RESPONSE_TYPE_OAI_RESP,
    TASK_RESPONSE_TYPE_OAI_EMBD,
    TASK_RESPONSE_TYPE_ANTHROPIC,
};

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};

struct task_params {
    bool stream = true;
    bool include_usage = false;
    bool cache_prompt = true; // remember the prompt to avoid reprocessing the entire prompt
    bool return_tokens = false;
    bool return_progress = false;

    int32_t n_keep = 0; // number of tokens to keep from initial prompt
    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
    int32_t n_cmpl = 1; // number of completions to generate from this prompt

    int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)

    int64_t t_max_prompt_ms = -1; // TODO: implement
    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

    std::map<int, float> lora; // mapping adapter ID -> scale

    std::vector<std::string> antiprompt;
    std::vector<std::string> response_fields;

    bool timings_per_token = false;
    bool post_sampling_probs = false;

    struct common_params_sampling sampling;
    struct common_params_speculative speculative;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;

    // per-request parameters for chat parsing
    common_chat_parser_params chat_parser_params;

    // Embeddings
    int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)

    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
    json to_json(bool only_metrics = false) const;
};
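
// Illustrative sketch (not taken from the server sources): a non-HTTP caller,
// e.g. a test harness, could fill a minimal task_params directly instead of
// going through server_task::params_from_json_cmpl(); only fields declared
// above are set, everything else keeps its default.
//
//   task_params params;
//   params.stream        = false;
//   params.n_predict     = 128;
//   params.antiprompt    = { "</s>" };
//   params.sampling.seed = 42;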

// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
    // tracking diffs for partial tool calls
    std::vector<common_chat_msg_diff> diffs;
    common_chat_parser_params chat_parser_params;
    common_chat_msg chat_msg;
    std::string generated_text; // append new chunks of generated text here
    std::vector<std::string> generated_tool_call_ids;

    // for the OpenAI Responses and Anthropic streaming APIs:
    // track output item / content block state across chunks
    bool thinking_block_started = false;
    bool text_block_started = false;

    // for the OpenAI Responses streaming API
    const std::string oai_resp_id;
    const std::string oai_resp_reasoning_id;
    const std::string oai_resp_message_id;
    std::string oai_resp_fc_id; // function call ID for the current args delta

    task_result_state(const common_chat_parser_params & chat_parser_params)
        : chat_parser_params(chat_parser_params)
        , oai_resp_id("resp_" + random_string())
        , oai_resp_reasoning_id("rs_" + random_string())
        , oai_resp_message_id("msg_" + random_string()) {}

    // parse partial tool calls and update the internal state
    common_chat_msg update_chat_msg(
            const std::string & text_added,
            bool is_partial,
            std::vector<common_chat_msg_diff> & diffs);
};
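
// Illustrative sketch (an assumption about the intended streaming flow, not a
// verbatim excerpt): on every newly generated chunk, the owner of the state
// appends the text and collects the resulting message diffs, e.g.
//
//   std::vector<common_chat_msg_diff> diffs;
//   common_chat_msg msg = state.update_chat_msg(chunk, /*is_partial=*/true, diffs);
//
// where `state` and `chunk` are hypothetical locals of the caller; the
// server_task_result_cmpl_* structs below drive this via their update() overrides.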

struct server_task {
    int id = -1; // to be filled by server_queue

    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
    size_t index = 0; // used when there are multiple prompts (batch request)

    // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
    int id_slot = -1;

    // used by parallel sampling (multiple completions from the same prompt)
    int id_parent = -1;
    // temporary store of child tasks for scheduling
    // note: accessing elements is invalid after the task is moved to a server_slot
    std::vector<server_task> child_tasks;

    // used by SERVER_TASK_TYPE_INFERENCE
    task_params params;
    server_tokens tokens;

    // only used by the CLI; this allows tokenizing CLI inputs on the server side
    // we need this because mtmd_context and vocab are not accessible outside of server_context
    bool cli = false;
    std::string cli_prompt;
    std::vector<raw_buffer> cli_files;

    server_task_type type;

    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
    struct slot_action {
        int id_slot;
        std::string filename;
        std::string filepath;
    };
    slot_action slot_action;

    // used by SERVER_TASK_TYPE_METRICS
    bool metrics_reset_bucket = false;

    // used by SERVER_TASK_TYPE_SET_LORA
    std::map<int, float> set_lora; // mapping adapter ID -> scale

    server_task() = default;

    server_task(server_task_type type) : type(type) {}

    int32_t n_tokens() const {
        return tokens.size();
    }

    bool need_embd() const {
        switch (type) {
            case SERVER_TASK_TYPE_EMBEDDING:
            case SERVER_TASK_TYPE_RERANK:
                return true;
            default:
                return false;
        }
    }

    bool need_logits() const {
        switch (type) {
            case SERVER_TASK_TYPE_COMPLETION:
            case SERVER_TASK_TYPE_INFILL:
                return true;
            default:
                return false;
        }
    }

    bool need_sampling() const {
        switch (type) {
            case SERVER_TASK_TYPE_COMPLETION:
            case SERVER_TASK_TYPE_INFILL:
                return true;
            default:
                return false;
        }
    }

    static task_params params_from_json_cmpl(
            const llama_vocab * vocab,
            const common_params & params_base,
            const int n_ctx_slot,
            const json & data);

    // utility function
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
            for (auto & child : tasks[i].child_tasks) {
                ids.insert(child.id);
            }
        }
        return ids;
    }

    void add_child(int id_parent, int id_child) {
        server_task copy;

        copy.id = id_child;
        copy.id_parent = id_parent;
        copy.params = params;
        copy.type = type;
        copy.tokens = tokens.clone();
        copy.id_slot = -1; // child tasks cannot specify a slot

        // use a different sampling seed for each child
        // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
        if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
            copy.params.sampling.seed += (uint32_t) child_tasks.size() + 1;
        }

        child_tasks.push_back(std::move(copy));
    }
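
    // Illustrative sketch (not part of the actual scheduling code): a caller
    // requesting several completions could fan the parent task out into
    // child tasks, assuming `alloc_task_id()` is some external ID generator:
    //
    //   for (int i = 1; i < task.params.n_cmpl; i++) {
    //       task.add_child(task.id, alloc_task_id());
    //   }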

    // the task will be moved into the queue, then onto a slot
    // however, the state must be kept by the caller (e.g., the HTTP thread)
    task_result_state create_state() const {
        return task_result_state(params.chat_parser_params);
    }

    bool is_parent() const {
        return child_tasks.size() > 0;
    }

    bool is_child() const {
        return id_parent != -1;
    }
};

struct result_timings {
    int32_t cache_n = -1;

    int32_t prompt_n = -1;
    double prompt_ms;
    double prompt_per_token_ms;
    double prompt_per_second;

    int32_t predicted_n = -1;
    double predicted_ms;
    double predicted_per_token_ms;
    double predicted_per_second;

    // Optional speculative metrics - only included when > 0
    int32_t draft_n = 0;
    int32_t draft_n_accepted = 0;

    json to_json() const;
};
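
// The derived timing fields are presumably filled in from the raw counters
// along these lines (a sketch, not the authoritative formulas):
//
//   prompt_per_token_ms    = prompt_ms / prompt_n;
//   prompt_per_second      = 1e3 * prompt_n / prompt_ms;
//   predicted_per_token_ms = predicted_ms / predicted_n;
//   predicted_per_second   = 1e3 * predicted_n / predicted_ms;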

struct result_prompt_progress {
    int32_t total = 0;
    int32_t cache = 0;
    int32_t processed = 0;
    int64_t time_ms = 0;

    json to_json() const;
};

struct server_task_result {
    int id = -1;
    int id_slot = -1;

    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
    size_t index = 0; // to be used for batched tasks

    virtual bool is_error() {
        // only used by server_task_result_error
        return false;
    }
    virtual bool is_stop() {
        // only used by server_task_result_cmpl_*
        return true;
    }
    virtual void update(task_result_state &) {
        // only used by server_task_result_cmpl_*
    }
    virtual json to_json() = 0;
    virtual ~server_task_result() = default;
};

// using unique_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
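
// Illustrative sketch of how a consumer (e.g. an HTTP handler) might drive a
// streamed task; `wait_for_next_result()` and `send_chunk()` are hypothetical
// stand-ins for whatever queue/transport mechanism is actually used:
//
//   task_result_state state = task.create_state();
//   while (true) {
//       server_task_result_ptr res = wait_for_next_result(task.id);
//       if (res->is_error()) {
//           send_chunk(res->to_json()); // report the error and stop
//           break;
//       }
//       res->update(state);          // advance streaming / tool-call state
//       send_chunk(res->to_json());
//       if (res->is_stop()) {
//           break;                   // final result for this task
//       }
//   }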

struct completion_token_output {
    llama_token tok;
    float prob;
    std::string text_to_send;
    struct prob_info {
        llama_token tok;
        std::string txt;
        float prob;
    };
    std::vector<prob_info> probs;

    json to_json(bool post_sampling_probs) const;

    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);

    static float logarithm(float x);

    static std::vector<unsigned char> str_to_bytes(const std::string & str);
};

struct server_task_result_cmpl_final : server_task_result {
    std::string content;
    llama_tokens tokens;

    bool stream;
    bool include_usage;
    result_timings timings;
    std::string prompt;

    bool truncated;
    int32_t n_decoded;
    int32_t n_prompt_tokens;
    int32_t n_tokens_cached;
    bool has_new_line;
    std::string stopping_word;
    stop_type stop = STOP_TYPE_NONE;

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
    std::vector<std::string> response_fields;

    task_params generation_params;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_msg; // to be populated by update()

    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
    bool is_updated = false;

    // for OpenAI Responses API
    std::string oai_resp_id;
    std::string oai_resp_reasoning_id;
    std::string oai_resp_message_id;

    virtual bool is_stop() override {
        return true; // in stream mode, final responses are considered stop
    }

    virtual json to_json() override;

    virtual void update(task_result_state & state) override {
        is_updated = true;
        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);

        oai_resp_id = state.oai_resp_id;
        oai_resp_reasoning_id = state.oai_resp_reasoning_id;
        oai_resp_message_id = state.oai_resp_message_id;
    }

    json to_json_non_oaicompat();

    json to_json_oaicompat();

    json to_json_oaicompat_chat();

    json to_json_oaicompat_chat_stream();

    json to_json_oaicompat_resp();

    json to_json_oaicompat_resp_stream();

    json to_json_anthropic();

    json to_json_anthropic_stream();
};

struct server_task_result_cmpl_partial : server_task_result {
    std::string content;
    llama_tokens tokens;

    int32_t n_decoded;
    int32_t n_prompt_tokens;

    bool post_sampling_probs;
    bool is_progress = false;
    completion_token_output prob_output;
    result_timings timings;
    result_prompt_progress progress;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
    bool is_updated = false;

    // Streaming state copied from task_result_state for this chunk
    bool thinking_block_started = false;
    bool text_block_started = false;

    // for OpenAI Responses API
    std::string oai_resp_id;
    std::string oai_resp_reasoning_id;
    std::string oai_resp_message_id;
    std::string oai_resp_fc_id;

    // for Anthropic API: track if any reasoning content has been generated
    bool anthropic_has_reasoning = false;

    virtual bool is_stop() override {
        return false; // in stream mode, partial responses are not considered stop
    }

    virtual void update(task_result_state & state) override;

    virtual json to_json() override;

    json to_json_non_oaicompat();

    json to_json_oaicompat();

    json to_json_oaicompat_chat();

    json to_json_oaicompat_resp();

    json to_json_anthropic();
};

struct server_task_result_embd : server_task_result {
    std::vector<std::vector<float>> embedding;

    int32_t n_tokens;

    // response formatting
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;

    virtual json to_json() override;

    json to_json_non_oaicompat();

    json to_json_oaicompat();
};

struct server_task_result_rerank : server_task_result {
    float score = -1e6;

    int32_t n_tokens;

    virtual json to_json() override;
};

struct server_task_result_error : server_task_result {
    error_type err_type = ERROR_TYPE_SERVER;
    std::string err_msg;

    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
    int32_t n_prompt_tokens = 0;
    int32_t n_ctx = 0;

    virtual bool is_error() override {
        return true;
    }

    virtual json to_json() override;
};

struct server_task_result_metrics : server_task_result {
    int n_idle_slots;
    int n_processing_slots;
    int n_tasks_deferred;
    int64_t t_start;

    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total = 0;
    uint64_t n_tokens_predicted_total = 0;
    uint64_t t_tokens_generation_total = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing = 0;

    uint64_t n_tokens_predicted = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total = 0;
    uint64_t n_busy_slots_total = 0;

    // while we could also use std::vector<server_slot>, this would require copying the slot object, which can be quite messy
    // therefore, we use json to temporarily store the slot.to_json() result
    json slots_data = json::array();

    virtual json to_json() override;
};

struct server_task_result_slot_save_load : server_task_result {
    std::string filename;
    bool is_save; // true = save, false = load

    size_t n_tokens;
    size_t n_bytes;
    double t_ms;

    virtual json to_json() override;
};

struct server_task_result_slot_erase : server_task_result {
    size_t n_erased;

    virtual json to_json() override;
};

struct server_task_result_get_lora : server_task_result {
    struct lora {
        common_adapter_lora_info info;
        std::string alora_invocation_string;
        llama_tokens alora_invocation_tokens;
    };
    std::vector<lora> loras;

    virtual json to_json() override;
};

struct server_task_result_apply_lora : server_task_result {
    virtual json to_json() override;
};

struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    size_t size() const {
        return data.size();
    }
};

struct server_prompt {
    server_tokens tokens;

    std::vector<uint8_t> data;

    std::list<server_prompt_checkpoint> checkpoints;

    size_t size() const {
        size_t res = data.size();

        for (const auto & checkpoint : checkpoints) {
            res += checkpoint.size();
        }

        return res;
    }

    int n_tokens() const {
        return tokens.size();
    }

    server_prompt clone() const {
        return server_prompt {
            tokens.clone(),
            data,
            checkpoints
        };
    }
};

struct server_prompt_cache {
    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
        this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
        this->limit_tokens = limit_tokens;
    }

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    size_t size() const;

    size_t n_tokens() const;

    server_prompt * alloc(const server_prompt & prompt, size_t state_size);

    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);

    void update();
};
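
// Illustrative usage (a sketch; everything not declared above is an
// assumption): a cache limited to 512 MiB of state data with no limit on the
// number of tokens, where update() is expected to enforce limit_size /
// limit_tokens after new states have been added:
//
//   server_prompt_cache prompt_cache(/*limit_size_mib=*/512, /*limit_tokens=*/0);
//   ...
//   prompt_cache.update();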