#include "server-context.h"
#include "server-http.h"
#include "server-models.h"

#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <atomic>
#include <exception>
#include <signal.h>
#include <thread> // for std::thread::hardware_concurrency

#if defined(_WIN32)
#include <windows.h>
#endif

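// shutdown_handler is assigned in main() before the signal handlers below are installed;
// is_terminating guards against handling a second interrupt while shutdown is already in progress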
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

static inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // if shutdown hangs, a second Ctrl+C force-terminates the server
        // this is for developer convenience; it can be removed once the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    shutdown_handler(signal);
}

// wrapper that catches exceptions thrown by a handler and logs them
// this guarantees that a handler_t never throws; exceptions are converted into error responses instead
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
    return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
        std::string message;
        error_type error;
        try {
            return func(req);
        } catch (const std::invalid_argument & e) {
            // treat invalid_argument as invalid request (400)
            error = ERROR_TYPE_INVALID_REQUEST;
            message = e.what();
        } catch (const std::exception & e) {
            // treat other exceptions as server error (500)
            error = ERROR_TYPE_SERVER;
            message = e.what();
        } catch (...) {
            error = ERROR_TYPE_SERVER;
            message = "unknown error";
        }

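        // build a JSON error response; if formatting itself throws, fall back to a plain 500 body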
        auto res = std::make_unique<server_http_res>();
        res->status = 500;
        try {
            json error_data = format_error_response(message, error);
            res->status = json_value(error_data, "code", 500);
            res->data = safe_json_to_str({{ "error", error_data }});
            SRV_WRN("got exception: %s\n", res->data.c_str());
        } catch (const std::exception & e) {
            SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
            res->data = "Internal Server Error";
        }
        return res;
    };
}

int main(int argc, char ** argv) {
    // arguments required by this example
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
        return 1;
    }

    // validate batch size for embeddings
    // embeddings require all tokens to be processed in a single ubatch
    // see https://github.com/ggml-org/llama.cpp/issues/12836
    if (params.embedding && params.n_batch > params.n_ubatch) {
        LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
        LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
        params.n_batch = params.n_ubatch;
    }

    if (params.n_parallel < 0) {
        LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);

        params.n_parallel = 4;
        params.kv_unified = true;
    }

    // for consistency between router mode and single-model mode, use the model name as the alias
    if (params.model_alias.empty() && !params.model.name.empty()) {
        params.model_alias = params.model.name;
    }

    common_init();

    // struct that contains the llama context and inference state
    server_context ctx_server;

    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

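    // HTTP layer: routes are registered on ctx_http and served from a background thread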
    server_http_context ctx_http;
    if (!ctx_http.init(params)) {
        LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
        return 1;
    }

    //
    // Router
    //

    // register API routes
    server_routes routes(params, ctx_server);

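    // router mode: no model path was given, so requests are proxied to the server instances managed by this process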
    bool is_router_server = params.model.path.empty();
    std::optional<server_models_routes> models_routes{};
    if (is_router_server) {
        // set up the server instances manager
        try {
            models_routes.emplace(params, argc, argv);
        } catch (const std::exception & e) {
            LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what());
            return 1;
        }

        // proxy handlers
        // note: routes.get_health stays the same
        routes.get_metrics                 = models_routes->proxy_get;
        routes.post_props                  = models_routes->proxy_post;
        routes.get_api_show                = models_routes->proxy_get;
        routes.post_completions            = models_routes->proxy_post;
        routes.post_completions_oai        = models_routes->proxy_post;
        routes.post_chat_completions       = models_routes->proxy_post;
        routes.post_responses_oai          = models_routes->proxy_post;
        routes.post_anthropic_messages     = models_routes->proxy_post;
        routes.post_anthropic_count_tokens = models_routes->proxy_post;
        routes.post_infill                 = models_routes->proxy_post;
        routes.post_embeddings             = models_routes->proxy_post;
        routes.post_embeddings_oai         = models_routes->proxy_post;
        routes.post_rerank                 = models_routes->proxy_post;
        routes.post_tokenize               = models_routes->proxy_post;
        routes.post_detokenize             = models_routes->proxy_post;
        routes.post_apply_template         = models_routes->proxy_post;
        routes.get_lora_adapters           = models_routes->proxy_get;
        routes.post_lora_adapters          = models_routes->proxy_post;
        routes.get_slots                   = models_routes->proxy_get;
        routes.post_slots                  = models_routes->proxy_post;

        // custom routes for the router
        routes.get_props  = models_routes->get_router_props;
        routes.get_models = models_routes->get_router_models;
        ctx_http.post("/models/load",   ex_wrapper(models_routes->post_router_models_load));
        ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
    }

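    // register HTTP endpoints; in router mode most of these handlers were replaced with proxies above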
    ctx_http.get ("/health",              ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/v1/health",           ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/metrics",             ex_wrapper(routes.get_metrics));
    ctx_http.get ("/props",               ex_wrapper(routes.get_props));
    ctx_http.post("/props",               ex_wrapper(routes.post_props));
    ctx_http.post("/api/show",            ex_wrapper(routes.get_api_show));
    ctx_http.get ("/models",              ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/v1/models",           ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/api/tags",            ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
    ctx_http.post("/completion",          ex_wrapper(routes.post_completions)); // legacy
    ctx_http.post("/completions",         ex_wrapper(routes.post_completions));
    ctx_http.post("/v1/completions",      ex_wrapper(routes.post_completions_oai));
    ctx_http.post("/chat/completions",    ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/api/chat",            ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
    ctx_http.post("/v1/responses",        ex_wrapper(routes.post_responses_oai));
    ctx_http.post("/v1/messages",         ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
    ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
    ctx_http.post("/infill",              ex_wrapper(routes.post_infill));
    ctx_http.post("/embedding",           ex_wrapper(routes.post_embeddings)); // legacy
    ctx_http.post("/embeddings",          ex_wrapper(routes.post_embeddings));
    ctx_http.post("/v1/embeddings",       ex_wrapper(routes.post_embeddings_oai));
    ctx_http.post("/rerank",              ex_wrapper(routes.post_rerank));
    ctx_http.post("/reranking",           ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/rerank",           ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/reranking",        ex_wrapper(routes.post_rerank));
    ctx_http.post("/tokenize",            ex_wrapper(routes.post_tokenize));
    ctx_http.post("/detokenize",          ex_wrapper(routes.post_detokenize));
    ctx_http.post("/apply-template",      ex_wrapper(routes.post_apply_template));
    // LoRA adapters hotswap
    ctx_http.get ("/lora-adapters",       ex_wrapper(routes.get_lora_adapters));
    ctx_http.post("/lora-adapters",       ex_wrapper(routes.post_lora_adapters));
    // Save & load slots
    ctx_http.get ("/slots",               ex_wrapper(routes.get_slots));
    ctx_http.post("/slots/:id_slot",      ex_wrapper(routes.post_slots));

    //
    // Start the server
    //

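    // clean-up callback; assigned per mode below and invoked once serving stops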
    std::function<void()> clean_up;

    if (is_router_server) {
        LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);

        clean_up = [&models_routes]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            if (models_routes.has_value()) {
                models_routes->models.unload_all();
            }
            llama_backend_free();
        };

        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }
        ctx_http.is_ready.store(true);

        shutdown_handler = [&](int) {
            ctx_http.stop();
        };

    } else {
        // set up the clean-up function, to be called before exit
        clean_up = [&ctx_http, &ctx_server]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            ctx_http.stop();
            ctx_server.terminate();
            llama_backend_free();
        };

        // start the HTTP server before loading the model so it can already serve /health requests
        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }

        // load the model
        LOG_INF("%s: loading model\n", __func__);

        if (!ctx_server.load_model(params)) {
            clean_up();
            if (ctx_http.thread.joinable()) {
                ctx_http.thread.join();
            }
            LOG_ERR("%s: exiting due to model loading error\n", __func__);
            return 1;
        }

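        // refresh the metadata cached by the route handlers now that the model is loaded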
        routes.update_meta(ctx_server);
        ctx_http.is_ready.store(true);

        LOG_INF("%s: model loaded\n", __func__);

        shutdown_handler = [&](int) {
            // this will unblock start_loop()
            ctx_server.terminate();
        };
    }

    // TODO: refactor in common/console
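    // install Ctrl+C / termination handlers: sigaction on POSIX, a console ctrl handler on Windows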
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    if (is_router_server) {
        LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: NOTE: router mode is experimental\n", __func__);
        LOG_INF("%s:       it is not recommended to use this mode in untrusted environments\n", __func__);
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join(); // keep the main thread alive
        }

        // when the HTTP server stops, clean up and exit
        clean_up();
    } else {
        LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: starting the main loop...\n", __func__);

        // optionally, notify the router server that this instance is ready
        const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
        std::thread monitor_thread;
        if (router_port != nullptr) {
            monitor_thread = server_models::setup_child_server(shutdown_handler);
        }

        // this call blocks the main thread until queue_tasks.terminate() is called
        ctx_server.start_loop();

        clean_up();
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join();
        }
        if (monitor_thread.joinable()) {
            monitor_thread.join();
        }

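        // print a breakdown of memory usage for the llama context, if one was created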
        auto * ll_ctx = ctx_server.get_llama_context();
        if (ll_ctx != nullptr) {
            llama_memory_breakdown_print(ll_ctx);
        }
    }

    return 0;
}