1#pragma once
2
#include "common.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include "mtmd.h"

#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>

#include <cinttypes>
#include <map>
#include <string>
#include <vector>
15
// project-wide JSON type; nlohmann::ordered_json preserves the insertion order of object keys
// (per the nlohmann documentation), so serialized responses keep the order fields were added in
using json = nlohmann::ordered_json;
17
// Per-slot logging helpers.
// The message is prefixed with the calling function name (printed via "%12.*s", i.e. at most
// 12 characters, padded to 12), the slot id and the id of the slot's current task (-1 if
// `(slot).task` is null). `slot` must expose `.id` and a pointer-like `.task` with an `->id`.
// note: callers must pass at least one argument after `fmt`, because __VA_ARGS__ is expanded
//       right after a comma (e.g. SLT_INF(slot, "%s", "text\n")).
// SLT_CNT adds no prefix and does not use `slot`; it forwards to LOG_CNT (presumably to
// continue a previously started log line).
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

// Server-wide logging helpers: same as above but with a "srv" prefix and no slot/task info.
// The same "at least one argument after fmt" requirement applies; SRV_CNT adds no prefix.
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
29
// an owned buffer of raw bytes; used for the media/file payloads passed to
// process_mtmd_prompt and collected by oaicompat_chat_params_parse (see below)
using raw_buffer = std::vector<uint8_t>;
31
32template <typename T>
33static T json_value(const json & body, const std::string & key, const T & default_value) {
34 // Fallback null to default value
35 if (body.contains(key) && !body.at(key).is_null()) {
36 try {
37 return body.at(key);
38 } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
39 LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
40 return default_value;
41 }
42 } else {
43 return default_value;
44 }
45}
46
// Error categories passed to format_error_response(); the first five mirror the OpenAI
// error types listed at:
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
// The remaining ones are server-specific extensions.
// Values are spelled out explicitly and equal the implicit 0-based numbering.
enum error_type {
    ERROR_TYPE_INVALID_REQUEST     = 0,
    ERROR_TYPE_AUTHENTICATION      = 1,
    ERROR_TYPE_SERVER              = 2,
    ERROR_TYPE_NOT_FOUND           = 3,
    ERROR_TYPE_PERMISSION          = 4,
    ERROR_TYPE_UNAVAILABLE         = 5, // custom error
    ERROR_TYPE_NOT_SUPPORTED       = 6, // custom error
    ERROR_TYPE_EXCEED_CONTEXT_SIZE = 7, // custom error
};
58
59// thin wrapper around common_grammar_trigger with (de)serialization functions
60struct server_grammar_trigger {
61 common_grammar_trigger value;
62
63 server_grammar_trigger() = default;
64 server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
65 server_grammar_trigger(const json & in) {
66 value.type = (common_grammar_trigger_type) in.at("type").get<int>();
67 value.value = in.at("value").get<std::string>();
68 if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
69 value.token = (llama_token) in.at("token").get<int>();
70 }
71 }
72
73 json to_json() const {
74 json out {
75 {"type", (int) value.type},
76 {"value", value.value},
77 };
78 if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
79 out["token"] = (int) value.token;
80 }
81 return out;
82 }
83};
84
// build the JSON error payload for `message`, categorized by `type` (see error_type);
// presumably OpenAI-compatible in shape — confirm against the implementation
json format_error_response(const std::string & message, const enum error_type type);
86
//
// random string / id
//

// random string; presumably also the basis of the generated ids below — confirm in the .cpp
std::string random_string();
// id for a chat completion response (judging by the name, an OpenAI-style "chatcmpl" id)
std::string gen_chatcmplid();
// id for a tool call emitted in a chat response
std::string gen_tool_call_id();
94
//
// lora utils
//

// check whether the given lora set has only aloras activated (empty => false)
bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);

// if the two sets of loras are different, they require a cache clear unless the
// change is only from aloras to aloras.
bool lora_should_clear_cache(
        const std::vector<common_adapter_lora_info> & current,
        const std::vector<common_adapter_lora_info> & next);

// parse the lora adapters requested in `data`; the result maps an adapter id to its scale
// (presumably read from a "lora" array of {id, scale} objects — confirm in the implementation)
std::map<int, float> parse_lora_request(const json & data);

// true if both lists describe the same adapter configuration
// (presumably compared element-wise, including scales — confirm in the implementation)
bool are_lora_equal(
        const std::vector<common_adapter_lora_info> & l1,
        const std::vector<common_adapter_lora_info> & l2);

// get the ids (indices into `loras`, presumably) of all enabled loras
std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
116
//
// server_tokens
//

/**
 * server_tokens is a helper to manage the input tokens and media chunks (e.g. images) for the server.
 * it is made this way to simplify the logic of KV cache management: text tokens and media
 * placeholders live in a single token list, kept in sync with a map of media chunks.
 */
struct server_tokens {
    // whether this sequence may contain mtmd (media) chunks; set through the constructors' flag
    bool has_mtmd = false;

private: // disallow accessing these members directly, risking out-of-sync

    // map a **start** index in tokens to the media chunk that begins there
    // note: the order needs to be in sync with tokens
    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;

    // list of tokens
    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by a media chunk
    // otherwise, it is a normal text token
    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
    // note(2): for M-RoPE, an image can occupy a different number of pos; do not assume a 1-to-1 mapping tokens <-> pos
    llama_tokens tokens;

    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
    //   [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
    // idx  0   1   2   3   4    5      6      7      8      9     10
    // pos  0   1   2   3   4    5      5      5      7      7      7
    // map_idx_to_media will contain: {5, img0}, {8, img1}

public:
    server_tokens() = default;
    ~server_tokens() = default;

    // Prevent copying
    // TODO: server_tokens should be copyable - remove this:
    // (clone() below is the explicit way to obtain a copy)
    server_tokens(const server_tokens&) = delete;
    server_tokens& operator=(const server_tokens&) = delete;

    // Allow moving (usually implicitly generated if members are movable)
    server_tokens(server_tokens&&) = default;
    server_tokens& operator=(server_tokens&&) = default;

    // Allow accessing elements using [] operator
    // note: the non-const overload returns a copy, so tokens cannot be modified through it
    //       (see set_token()); a value of LLAMA_TOKEN_NULL marks a media position
    llama_token operator[](size_t index) { return tokens[index]; }
    const llama_token& operator[](size_t index) const { return tokens[index]; }

    // build from mtmd input chunks (text and media) or from plain text tokens
    server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
    server_tokens(const llama_tokens & tokens, bool has_mtmd);

    // for debugging
    std::string str() const;

    // presumably the position to use for the next appended token; it may differ from size()
    // because media chunks do not map 1-to-1 onto positions (see note(2) above)
    llama_pos pos_next() const;
    // presumably returns the media chunk that starts at token index idx (see map_idx_to_media)
    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;

    void push_back(llama_token tok);

    // will create a copy of the chunk if it contains non-text data
    void push_back(const mtmd_input_chunk * chunk);

    // appends server tokens, updates the media map. copies media chunks.
    void push_back(server_tokens & tokens);

    // for compatibility with context shift and prompt truncation
    void insert(const llama_tokens & inp_tokens);

    // for compatibility with speculative decoding, ctx shift, slot save/load
    const llama_tokens & get_text_tokens() const;

    // for compatibility with speculative decoding
    void set_token(llama_pos pos, llama_token id);

    // number of entries in the token list (media chunks count once per occupied token)
    size_t size() const { return tokens.size(); }

    bool empty() const { return tokens.empty(); }

    // drop all tokens and media chunks together, keeping both containers in sync
    void clear() {
        map_idx_to_media.clear();
        tokens.clear();
    }

    // truncate to the first n tokens (presumably also dropping media chunks past the cut)
    void keep_first(size_t n);

    std::string detokenize(const llama_context * ctx, bool special) const;

    // length of the common prefix with b (presumably media-aware) — used for KV cache reuse
    size_t get_common_prefix(const server_tokens & b) const;

    // make sure all text tokens are within the vocab range
    bool validate(const struct llama_context * ctx) const;

    // encode and decode the image chunk
    // the chunk is the one starting at token index idx; n_tokens_out receives the number of
    // tokens it occupies (presumably); the int32_t result is presumably a status code
    int32_t process_chunk(
            llama_context * ctx,
            mtmd_context * mctx,
            size_t idx,
            llama_pos pos,
            int32_t seq_id,
            size_t & n_tokens_out) const;

    // explicit copy (the copy constructor/assignment are deleted above)
    server_tokens clone() const;
};
219
220
//
// tokenizer and input processing utils
//

// is data an array of numbers (e.g. a pre-tokenized prompt)?
bool json_is_array_of_numbers(const json & data);

// is data an array containing BOTH numbers & strings?
bool json_is_array_of_mixed_numbers_strings(const json & data);

// does data (an array) contain any individual integers/tokens?
bool json_is_array_and_contains_numbers(const json & data);

// get the values of js at the given paths, each path being "/"-separated keys (e.g. "key1/key2")
// presumably returns an object keyed by path — confirm in the implementation
json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
235
/**
 * this handles 2 cases:
 * - only string, example: "string"
 * - mixed string and tokens, example: [12, 34, "string", 56, 78]
 * add_special / parse_special are presumably forwarded to the tokenizer for the string parts.
 */
llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);

// return the length of the leading part of text that forms a valid string:
// if the last character is potentially cut in half (incomplete multi-byte sequence at the end),
// the returned length stops right before that character
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
size_t validate_utf8(const std::string& text);
247
// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
// `files` holds the raw media payloads referenced by the prompt (both arguments are taken by value)
server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);

/**
 * break the input "prompt" object into multiple prompts if needed, then tokenize them
 * this supports these cases:
 * - "prompt": "string"
 * - "prompt": [12, 34, 56]
 * - "prompt": [12, 34, "string", 56, 78]
 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
 * and multiple prompts (multi-tasks):
 * - "prompt": ["string1", "string2"]
 * - "prompt": ["string1", [12, 34, 56]]
 * - "prompt": [[12, 34, 56], [78, 90, 12]]
 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
 * returns one server_tokens per prompt; mctx is presumably nullptr when multimodal input is not enabled
 */
std::vector<server_tokens> tokenize_input_prompts(
        const llama_vocab * vocab,
        mtmd_context * mctx,
        const json & json_prompt,
        bool add_special,
        bool parse_special);
270
271//
272// OAI utils
273//
274
275// global server parameters for chat formatting / parsing
276struct server_chat_params {
277 bool use_jinja;
278 bool prefill_assistant;
279 common_reasoning_format reasoning_format;
280 std::map<std::string, std::string> chat_template_kwargs; // mapping key --> json value
281 common_chat_templates_ptr tmpls;
282 bool allow_image;
283 bool allow_audio;
284 bool enable_thinking = true;
285 std::string media_path;
286};
287
// used by /completions endpoint
// presumably maps an OpenAI-style request body onto the server's internal parameters
json oaicompat_completion_params_parse(const json & body);

// used by /chat/completions endpoint
// `body` is taken by non-const reference, so it may be modified; media referenced by the
// request is presumably decoded and appended to `out_files` (see process_mtmd_prompt)
json oaicompat_chat_params_parse(
    json & body, /* openai api json semantics */
    const server_chat_params & opt,
    std::vector<raw_buffer> & out_files);

// convert OpenAI Responses API format to OpenAI Chat Completions API format
json convert_responses_to_chatcmpl(const json & body);

// convert Anthropic Messages API format to OpenAI Chat Completions API format
json convert_anthropic_to_oai(const json & body);
302
// TODO: move it to server-task.cpp
// build the OpenAI-compatible embeddings response for `embeddings`;
// use_base64: presumably emit each embedding base64-encoded instead of as a JSON array
json format_embeddings_response_oaicompat(
    const json & request,
    const std::string & model_name,
    const json & embeddings,
    bool use_base64 = false);

// TODO: move it to server-task.cpp
// build the rerank response from `ranks`;
// is_tei_format: presumably selects a TEI-style (text-embeddings-inference) layout;
// texts: the ranked documents (taken by non-const reference);
// top_n: presumably limits the number of returned results — confirm in the implementation
json format_response_rerank(
        const json & request,
        const std::string & model_name,
        const json & ranks,
        bool is_tei_format,
        std::vector<std::string> & texts,
        int top_n);
318
//
// other utils
//

// candidate tokens with their probabilities for output `idx` of ctx
// (presumably a softmax over the logits, sorted by probability — confirm in the implementation)
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);

// serialize data to a string; judging by the name, without throwing on invalid utf-8 — confirm
std::string safe_json_to_str(const json & data);

// detokenize `tokens` into a single string (e.g. for logging)
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);

// format incomplete utf-8 multibyte character for output
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
332
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_oai_sse(const json & data);

// SSE formatting for the OpenAI Responses API (judging by the name)
std::string format_oai_resp_sse(const json & data);

// format Anthropic-style SSE with event types
std::string format_anthropic_sse(const json & data);

// true if str is entirely valid utf-8
bool is_valid_utf8(const std::string & str);
343
//
// formatting task prompts (infill / rerank)
// TODO: move these to server-task.cpp
//
348
// build the token sequence for a fill-in-the-middle (infill) request from the prefix, suffix and
// extra context, presumably trimmed to fit n_batch / n_predict / n_ctx;
// spm_infill presumably selects the suffix-prefix-middle ordering — confirm in the implementation
llama_tokens format_prompt_infill(
        const llama_vocab * vocab,
        const json & input_prefix,
        const json & input_suffix,
        const json & input_extra,
        const int n_batch,
        const int n_predict,
        const int n_ctx,
        const bool spm_infill,
        const llama_tokens & tokens_prompt);
359
// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
// returns server_tokens (not plain llama_tokens), presumably so that mtmd input via mctx can be supported
server_tokens format_prompt_rerank(
        const struct llama_model * model,
        const struct llama_vocab * vocab,
        mtmd_context * mctx,
        const std::string & query,
        const std::string & doc);