#include "server-context.h"
#include "server-http.h"
#include "server-models.h"

#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <atomic>
#include <cstdlib>    // for std::getenv
#include <exception>
#include <functional> // for std::function
#include <memory>     // for std::make_unique
#include <optional>   // for std::optional
#include <signal.h>
#include <stdexcept>  // for std::invalid_argument
#include <thread>     // for std::thread::hardware_concurrency

#if defined(_WIN32)
#include <windows.h>
#endif

static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

static inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // in case the server hangs, hitting Ctrl+C a second time force-terminates it
        // this is for a better developer experience; we can remove it once the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    shutdown_handler(signal);
}

// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
    return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
        std::string message;
        error_type error;
        try {
            return func(req);
        } catch (const std::invalid_argument & e) {
            // treat invalid_argument as invalid request (400)
            error   = ERROR_TYPE_INVALID_REQUEST;
            message = e.what();
        } catch (const std::exception & e) {
            // treat other exceptions as server error (500)
            error   = ERROR_TYPE_SERVER;
            message = e.what();
        } catch (...) {
            error   = ERROR_TYPE_SERVER;
            message = "unknown error";
        }

        auto res = std::make_unique<server_http_res>();
        res->status = 500;
        try {
            json error_data = format_error_response(message, error);
            res->status     = json_value(error_data, "code", 500);
            res->data       = safe_json_to_str({{ "error", error_data }});
            SRV_WRN("got exception: %s\n", res->data.c_str());
        } catch (const std::exception & e) {
            SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
            res->data = "Internal Server Error";
        }
        return res;
    };
}
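
// usage sketch (illustrative only; `my_handler` is hypothetical):
//
//   ctx_http.post("/my-endpoint", ex_wrapper(my_handler));
//
// an exception thrown by the handler is converted into a JSON error body of the
// form {"error": {...}}; a std::invalid_argument becomes a 400 (assuming
// format_error_response maps ERROR_TYPE_INVALID_REQUEST to code 400), and any
// other exception becomes a 500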

int main(int argc, char ** argv) {
    // own arguments required by this example
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
        return 1;
    }

    // validate batch size for embeddings
    // embeddings require all tokens to be processed in a single ubatch
    // see https://github.com/ggml-org/llama.cpp/issues/12836
    if (params.embedding && params.n_batch > params.n_ubatch) {
        LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
        LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
        params.n_batch = params.n_ubatch;
    }
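    // for example, `llama-server --embedding -b 4096 -ub 512 ...` would be clamped to
    // n_batch = 512 by the check above (flag names per the common llama.cpp CLI)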

    if (params.n_parallel < 0) {
        LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);

        params.n_parallel = 4;
        params.kv_unified = true;
    }

    // for consistency between router mode and single-model mode, use the model name as the alias
    if (params.model_alias.empty() && !params.model.name.empty()) {
        params.model_alias = params.model.name;
    }

    common_init();

    // struct that contains the llama context and handles inference
    server_context ctx_server;

    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, (int) std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

    server_http_context ctx_http;
    if (!ctx_http.init(params)) {
        LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
        return 1;
    }

    //
    // Router
    //

    // register API routes
    server_routes routes(params, ctx_server);

    bool is_router_server = params.model.path.empty();
    std::optional<server_models_routes> models_routes{};
    if (is_router_server) {
        // set up the server instances manager
        try {
            models_routes.emplace(params, argc, argv);
        } catch (const std::exception & e) {
            LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what());
            return 1;
        }

        // proxy handlers
        // note: routes.get_health stays the same
        routes.get_metrics                 = models_routes->proxy_get;
        routes.post_props                  = models_routes->proxy_post;
        routes.get_api_show                = models_routes->proxy_get;
        routes.post_completions            = models_routes->proxy_post;
        routes.post_completions_oai        = models_routes->proxy_post;
        routes.post_chat_completions       = models_routes->proxy_post;
        routes.post_responses_oai          = models_routes->proxy_post;
        routes.post_anthropic_messages     = models_routes->proxy_post;
        routes.post_anthropic_count_tokens = models_routes->proxy_post;
        routes.post_infill                 = models_routes->proxy_post;
        routes.post_embeddings             = models_routes->proxy_post;
        routes.post_embeddings_oai         = models_routes->proxy_post;
        routes.post_rerank                 = models_routes->proxy_post;
        routes.post_tokenize               = models_routes->proxy_post;
        routes.post_detokenize             = models_routes->proxy_post;
        routes.post_apply_template         = models_routes->proxy_post;
        routes.get_lora_adapters           = models_routes->proxy_get;
        routes.post_lora_adapters          = models_routes->proxy_post;
        routes.get_slots                   = models_routes->proxy_get;
        routes.post_slots                  = models_routes->proxy_post;

        // custom routes for the router
        routes.get_props  = models_routes->get_router_props;
        routes.get_models = models_routes->get_router_models;
        ctx_http.post("/models/load",   ex_wrapper(models_routes->post_router_models_load));
        ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
    }

    ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
    ctx_http.get ("/props", ex_wrapper(routes.get_props));
    ctx_http.post("/props", ex_wrapper(routes.post_props));
    ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
    ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama-specific; public endpoint (no API key check)
    ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
    ctx_http.post("/completions", ex_wrapper(routes.post_completions));
    ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
    ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama-specific endpoint
    ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
    ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
    ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
    ctx_http.post("/infill", ex_wrapper(routes.post_infill));
    ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
    ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
    ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
    ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
    ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
    ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
    ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
    ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
    // LoRA adapters hotswap
    ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
    ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
    // save & load slots
    ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
    ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
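
    // usage sketch (assuming the default listen address http://127.0.0.1:8080 and no --api-key):
    //
    //   curl http://127.0.0.1:8080/health
    //   curl http://127.0.0.1:8080/v1/chat/completions -H "Content-Type: application/json" \
    //        -d '{"messages": [{"role": "user", "content": "Hello"}]}'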

    //
    // Start the server
    //

    std::function<void()> clean_up;

    if (is_router_server) {
        LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);

        clean_up = [&models_routes]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            if (models_routes.has_value()) {
                models_routes->models.unload_all();
            }
            llama_backend_free();
        };

        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }
        ctx_http.is_ready.store(true);

        shutdown_handler = [&](int) {
            ctx_http.stop();
        };

    } else {
        // set up the clean-up function, to be called before exit
        clean_up = [&ctx_http, &ctx_server]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            ctx_http.stop();
            ctx_server.terminate();
            llama_backend_free();
        };

        // start the HTTP server before loading the model, so that /health requests can be served during loading
        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }

        // load the model
        LOG_INF("%s: loading model\n", __func__);

        if (!ctx_server.load_model(params)) {
            clean_up();
            if (ctx_http.thread.joinable()) {
                ctx_http.thread.join();
            }
            LOG_ERR("%s: exiting due to model loading error\n", __func__);
            return 1;
        }

        routes.update_meta(ctx_server);
        ctx_http.is_ready.store(true);

        LOG_INF("%s: model loaded\n", __func__);

        shutdown_handler = [&](int) {
            // this will unblock start_loop()
            ctx_server.terminate();
        };
    }

    // TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    if (is_router_server) {
        LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: NOTE: router mode is experimental\n", __func__);
        LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__);
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join(); // keep the main thread alive
        }

        // when the HTTP server stops, clean up and exit
        clean_up();
    } else {
        LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: starting the main loop...\n", __func__);

        // optionally, notify the router server that this instance is ready
        const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
        std::thread monitor_thread;
        if (router_port != nullptr) {
            monitor_thread = server_models::setup_child_server(shutdown_handler);
        }
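
        // (the router is assumed to set LLAMA_SERVER_ROUTER_PORT when spawning child
        //  instances; setup_child_server presumably starts a thread that reports back
        //  to the router and invokes shutdown_handler if the router goes away)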

        // this call blocks the main thread until queue_tasks.terminate() is called
        ctx_server.start_loop();

        clean_up();
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join();
        }
        if (monitor_thread.joinable()) {
            monitor_thread.join();
        }

        auto * ll_ctx = ctx_server.get_llama_context();
        if (ll_ctx != nullptr) {
            llama_memory_breakdown_print(ll_ctx);
        }
    }

    return 0;
}