1#include "server-common.h"
2#include "server-models.h"
3
4#include "preset.h"
5#include "download.h"
6
7#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
8#include <sheredom/subprocess.h>
9
10#include <functional>
11#include <algorithm>
12#include <thread>
13#include <mutex>
14#include <condition_variable>
15#include <cstring>
16#include <atomic>
17#include <chrono>
18#include <queue>
19#include <filesystem>
20#include <cstring>
21
22#ifdef _WIN32
23#include <winsock2.h>
24#include <windows.h>
25#else
26#include <sys/socket.h>
27#include <netinet/in.h>
28#include <arpa/inet.h>
29#include <unistd.h>
30extern char **environ;
31#endif
32
33#if defined(__APPLE__) && defined(__MACH__)
34// macOS: use _NSGetExecutablePath to get the executable path
35#include <mach-o/dyld.h>
36#include <limits.h>
37#endif
38
39#define DEFAULT_STOP_TIMEOUT 10 // seconds
40
41#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
42#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
43
44// address for child process, this is needed because router may run on 0.0.0.0
45// ref: https://github.com/ggml-org/llama.cpp/issues/17862
46#define CHILD_ADDR "127.0.0.1"
47
// Resolve the absolute path of the currently running server binary.
// The router re-executes this same binary to spawn child model instances.
// Throws std::runtime_error if the platform-specific lookup fails.
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    // len == buffer size means the path was truncated, treat as failure
    if (len == 0 || len >= _countof(buf)) {
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);

    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            // canonical() can throw (e.g. filesystem error); fall back to the raw path
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        // (_NSGetExecutablePath writes the required size back into `size` on failure)
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    // Linux (and other procfs platforms): read the symlink to our own binary
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    // readlink does not NUL-terminate, so build the string from the byte count
    return std::filesystem::path(std::string(path, count));
#endif
}
88
89static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
90 preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
91 preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
92 preset.unset_option("LLAMA_API_KEY");
93 preset.unset_option("LLAMA_ARG_MODELS_DIR");
94 preset.unset_option("LLAMA_ARG_MODELS_MAX");
95 preset.unset_option("LLAMA_ARG_MODELS_PRESET");
96 preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
97 if (unset_model_args) {
98 preset.unset_option("LLAMA_ARG_MODEL");
99 preset.unset_option("LLAMA_ARG_MMPROJ");
100 preset.unset_option("LLAMA_ARG_HF_REPO");
101 }
102}
103
#ifdef _WIN32
// Convert a NUL-terminated wide (UTF-16) string to UTF-8.
// Returns an empty string for null/empty input or on conversion failure.
static std::string wide_to_utf8(const wchar_t * ws) {
    if (ws == nullptr || ws[0] == L'\0') {
        return std::string();
    }

    const int n_wide  = static_cast<int>(std::wcslen(ws));
    // first call computes the required UTF-8 byte count
    const int n_bytes = WideCharToMultiByte(CP_UTF8, 0, ws, n_wide, nullptr, 0, nullptr, nullptr);
    if (n_bytes == 0) {
        return std::string();
    }

    std::string out(n_bytes, '\0');
    WideCharToMultiByte(CP_UTF8, 0, ws, n_wide, out.data(), n_bytes, nullptr, nullptr);

    return out;
}
#endif
122
123static std::vector<std::string> get_environment() {
124 std::vector<std::string> env;
125
126#ifdef _WIN32
127 LPWCH env_block = GetEnvironmentStringsW();
128 if (!env_block) {
129 return env;
130 }
131 for (LPWCH e = env_block; *e; e += wcslen(e) + 1) {
132 env.emplace_back(wide_to_utf8(e));
133 }
134 FreeEnvironmentStringsW(env_block);
135#else
136 if (environ == nullptr) {
137 return env;
138 }
139 for (char ** e = environ; *e != nullptr; e++) {
140 env.emplace_back(*e);
141 }
142#endif
143
144 return env;
145}
146
// Re-render the command line for this model instance: strip router-only
// options, pin the host/port/alias for the child process, then convert
// the preset into an argv list rooted at `bin_path`.
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    // keep model-selection args (unset_model_args = false): the child needs them
    unset_reserved_args(preset, false);
    // bind the child to loopback only; the router itself may listen on 0.0.0.0
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate preset before rendering ?
    // render args
    args = preset.to_args(bin_path);
}
157
158//
159// server_models
160//
161
// Construct the model router: capture base params and the process
// environment, rebuild the CLI args into a reusable base preset, resolve
// the server binary path, then discover all available models.
// note: base_preset is produced by ctx_preset.load_from_args(), which
// relies on ctx_preset being initialized first (member declaration order).
server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv)
    : ctx_preset(LLAMA_EXAMPLE_SERVER),
      base_params(params),
      base_env(get_environment()),
      base_preset(ctx_preset.load_from_args(argc, argv)) {
    // clean up base preset
    // drop router-only AND model-selection args so they don't leak into every child
    unset_reserved_args(base_preset, true);
    // set binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        // best-effort fallback: argv[0] may be relative, but is better than nothing
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}
182
183void server_models::add_model(server_model_meta && meta) {
184 if (mapping.find(meta.name) != mapping.end()) {
185 throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
186 }
187 meta.update_args(ctx_preset, bin_path); // render args
188 std::string name = meta.name;
189 mapping[name] = instance_t{
190 /* subproc */ std::make_shared<subprocess_s>(),
191 /* th */ std::thread(),
192 /* meta */ std::move(meta)
193 };
194}
195
196// TODO: allow refreshing cached model list
197void server_models::load_models() {
198 // loading models from 3 sources:
199 // 1. cached models
200 common_presets cached_models = ctx_preset.load_from_cache();
201 SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
202 // 2. local models from --models-dir
203 common_presets local_models;
204 if (!base_params.models_dir.empty()) {
205 local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
206 SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
207 }
208 // 3. custom-path models from presets
209 common_preset global = {};
210 common_presets custom_presets = {};
211 if (!base_params.models_preset.empty()) {
212 custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
213 SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
214 }
215
216 // cascade, apply global preset first
217 cached_models = ctx_preset.cascade(global, cached_models);
218 local_models = ctx_preset.cascade(global, local_models);
219 custom_presets = ctx_preset.cascade(global, custom_presets);
220
221 // note: if a model exists in both cached and local, local takes precedence
222 common_presets final_presets;
223 for (const auto & [name, preset] : cached_models) {
224 final_presets[name] = preset;
225 }
226 for (const auto & [name, preset] : local_models) {
227 final_presets[name] = preset;
228 }
229
230 // process custom presets from INI
231 for (const auto & [name, custom] : custom_presets) {
232 if (final_presets.find(name) != final_presets.end()) {
233 // apply custom config if exists
234 common_preset & target = final_presets[name];
235 target.merge(custom);
236 } else {
237 // otherwise add directly
238 final_presets[name] = custom;
239 }
240 }
241
242 // server base preset from CLI args take highest precedence
243 for (auto & [name, preset] : final_presets) {
244 preset.merge(base_preset);
245 }
246
247 // convert presets to server_model_meta and add to mapping
248 for (const auto & preset : final_presets) {
249 server_model_meta meta{
250 /* preset */ preset.second,
251 /* name */ preset.first,
252 /* port */ 0,
253 /* status */ SERVER_MODEL_STATUS_UNLOADED,
254 /* last_used */ 0,
255 /* args */ std::vector<std::string>(),
256 /* exit_code */ 0,
257 /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
258 };
259 add_model(std::move(meta));
260 }
261
262 // log available models
263 {
264 std::unordered_set<std::string> custom_names;
265 for (const auto & [name, preset] : custom_presets) {
266 custom_names.insert(name);
267 }
268 SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
269 for (const auto & [name, inst] : mapping) {
270 bool has_custom = custom_names.find(name) != custom_names.end();
271 SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
272 }
273 }
274
275 // handle custom stop-timeout option
276 for (auto & [name, inst] : mapping) {
277 std::string val;
278 if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
279 try {
280 inst.meta.stop_timeout = std::stoi(val);
281 } catch (...) {
282 SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
283 val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
284 inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
285 }
286 }
287 }
288
289 // load any autoload models
290 std::vector<std::string> models_to_load;
291 for (const auto & [name, inst] : mapping) {
292 std::string val;
293 if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
294 models_to_load.push_back(name);
295 }
296 }
297 if ((int)models_to_load.size() > base_params.models_max) {
298 throw std::runtime_error(string_format(
299 "number of models to load on startup (%zu) exceeds models_max (%d)",
300 models_to_load.size(),
301 base_params.models_max
302 ));
303 }
304 for (const auto & name : models_to_load) {
305 SRV_INF("(startup) loading model %s\n", name.c_str());
306 load(name);
307 }
308}
309
310void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
311 std::lock_guard<std::mutex> lk(mutex);
312 auto it = mapping.find(name);
313 if (it != mapping.end()) {
314 it->second.meta = meta;
315 }
316 cv.notify_all(); // notify wait_until_loaded
317}
318
319bool server_models::has_model(const std::string & name) {
320 std::lock_guard<std::mutex> lk(mutex);
321 return mapping.find(name) != mapping.end();
322}
323
324std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
325 std::lock_guard<std::mutex> lk(mutex);
326 auto it = mapping.find(name);
327 if (it != mapping.end()) {
328 return it->second.meta;
329 }
330 return std::nullopt;
331}
332
// Ask the OS for a currently-free TCP port by binding port 0 on a throwaway
// socket and reading back the assigned port. Returns -1 on any failure.
// note: the socket is closed before returning, so there is a small window
// in which another process could grab the port before the child binds it.
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif

    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0); // port 0 = let the OS pick an ephemeral port

    if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    // read back the port number the OS actually assigned
    if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

    int port = ntohs(serv_addr.sin_port);

    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif

    return port;
}
392
393// helper to convert vector<string> to char **
394// pointers are only valid as long as the original vector is valid
// helper to convert vector<string> to a NULL-terminated char ** array
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
    std::vector<char *> out(vec.size() + 1, nullptr); // last slot stays NULL
    for (size_t i = 0; i < vec.size(); i++) {
        out[i] = const_cast<char *>(vec[i].c_str());
    }
    return out;
}
404
405std::vector<server_model_meta> server_models::get_all_meta() {
406 std::lock_guard<std::mutex> lk(mutex);
407 std::vector<server_model_meta> result;
408 result.reserve(mapping.size());
409 for (const auto & [name, inst] : mapping) {
410 result.push_back(inst.meta);
411 }
412 return result;
413}
414
// If the number of active instances has reached models_max, pick the least
// recently used one, request its unload and block until it is fully stopped.
// A non-positive models_max disables the limit entirely.
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one of the servers if we passed the models_max (least recently used - LRU)
    std::string lru_model_name = "";
    // start the search from "now": any active instance was used before this
    int64_t lru_last_used = ggml_time_ms();
    size_t count_active = 0;
    {
        std::unique_lock<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        // unload() takes the mutex itself, so it must be called outside the lock above
        unload(lru_model_name);
        // wait for unload to complete
        // (the managing thread sets SERVER_MODEL_STATUS_UNLOADED and notifies `cv`)
        {
            std::unique_lock<std::mutex> lk(mutex);
            cv.wait(lk, [this, &lru_model_name]() {
                return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
            });
        }
    }
}
447
// Spawn a child server process for the given model and start a managing
// thread that forwards its logs, watches for the ready/stop signals and
// reaps the process on exit.
// Throws if the model is unknown, no free port could be obtained, or the
// child process could not be spawned.
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // make room first: may synchronously unload the least recently used instance
    unload_lru();

    std::lock_guard<std::mutex> lk(mutex);

    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }

    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();

    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }

    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);

        inst.meta.update_args(ctx_preset, bin_path); // render args

        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env = base_env; // copy
        // let the child know where the router is listening
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));

        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF("  %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging

        // argv/envp point into child_args/child_env, which stay alive until after spawn
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);

        // TODO @ngxson : maybe separate stdout and stderr in the future
        // so that we can use stdout for commands and stderr for logging
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }

        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }

    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
        FILE * stdin_file = subprocess_stdin(child_proc.get());
        FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr

        std::thread log_thread([&]() {
            // read stdout/stderr and forward to main server log
            // also handle status report from child process
            bool state_received = false; // true if child state received
            if (stdout_file) {
                char buffer[4096];
                // fgets blocks until the child writes a line or closes its output
                while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
                    LOG("[%5d] %s", port, buffer);
                    if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
                        // child process is ready
                        this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
                        state_received = true;
                    }
                }
            } else {
                SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
            }
        });

        std::thread stopping_thread([&]() {
            // thread to monitor stopping signal
            auto is_stopping = [this, &name]() {
                return this->stopping_models.find(name) != this->stopping_models.end();
            };
            // block until unload()/unload_all() marks this model as stopping
            {
                std::unique_lock<std::mutex> lk(this->mutex);
                this->cv_stop.wait(lk, is_stopping);
            }
            SRV_INF("stopping model instance name=%s\n", name.c_str());
            // send interrupt to child process
            fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
            fflush(stdin_file);
            // wait to stop gracefully or timeout
            int64_t start_time = ggml_time_ms();
            while (true) {
                std::unique_lock<std::mutex> lk(this->mutex);
                if (!is_stopping()) {
                    return; // already stopped
                }
                int64_t elapsed = ggml_time_ms() - start_time;
                if (elapsed >= stop_timeout * 1000) {
                    // timeout, force kill
                    SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
                    subprocess_terminate(child_proc.get());
                    return;
                }
                // re-check roughly once per second
                this->cv_stop.wait_for(lk, std::chrono::seconds(1));
            }
        });

        // we reach here when the child process exits
        // note: we cannot join() prior to this point because it will close stdin_file
        // (log_thread.join() returns once the child closes its output, i.e. it exited)
        if (log_thread.joinable()) {
            log_thread.join();
        }

        // stop the timeout monitoring thread
        // (erasing the name makes is_stopping() false, so stopping_thread returns)
        {
            std::lock_guard<std::mutex> lk(this->mutex);
            stopping_models.erase(name);
            cv_stop.notify_all();
        }
        if (stopping_thread.joinable()) {
            stopping_thread.join();
        }

        // get the exit code
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());

        // update status and exit code
        this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });

    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }

    mapping[name] = std::move(inst);
    cv.notify_all();
}
601
602void server_models::unload(const std::string & name) {
603 std::lock_guard<std::mutex> lk(mutex);
604 auto it = mapping.find(name);
605 if (it != mapping.end()) {
606 if (it->second.meta.is_active()) {
607 SRV_INF("unloading model instance name=%s\n", name.c_str());
608 stopping_models.insert(name);
609 cv_stop.notify_all();
610 // status change will be handled by the managing thread
611 } else {
612 SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
613 }
614 }
615}
616
617void server_models::unload_all() {
618 std::vector<std::thread> to_join;
619 {
620 std::lock_guard<std::mutex> lk(mutex);
621 for (auto & [name, inst] : mapping) {
622 if (inst.meta.is_active()) {
623 SRV_INF("unloading model instance name=%s\n", name.c_str());
624 stopping_models.insert(name);
625 cv_stop.notify_all();
626 // status change will be handled by the managing thread
627 }
628 // moving the thread to join list to avoid deadlock
629 to_join.push_back(std::move(inst.th));
630 }
631 }
632 for (auto & th : to_join) {
633 if (th.joinable()) {
634 th.join();
635 }
636 }
637}
638
639void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
640 std::unique_lock<std::mutex> lk(mutex);
641 auto it = mapping.find(name);
642 if (it != mapping.end()) {
643 auto & meta = it->second.meta;
644 meta.status = status;
645 meta.exit_code = exit_code;
646 }
647 cv.notify_all();
648}
649
650void server_models::wait_until_loaded(const std::string & name) {
651 std::unique_lock<std::mutex> lk(mutex);
652 cv.wait(lk, [this, &name]() {
653 auto it = mapping.find(name);
654 if (it != mapping.end()) {
655 return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
656 }
657 return false;
658 });
659}
660
661bool server_models::ensure_model_loaded(const std::string & name) {
662 auto meta = get_meta(name);
663 if (!meta.has_value()) {
664 throw std::runtime_error("model name=" + name + " is not found");
665 }
666 if (meta->status == SERVER_MODEL_STATUS_LOADED) {
667 return false; // already loaded
668 }
669 if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
670 SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
671 load(name);
672 }
673
674 // for loading state
675 SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
676 wait_until_loaded(name);
677
678 // check final status
679 meta = get_meta(name);
680 if (!meta.has_value() || meta->is_failed()) {
681 throw std::runtime_error("model name=" + name + " failed to load");
682 }
683
684 return true;
685}
686
// Forward an incoming HTTP request to the child instance serving `name`.
// Throws std::runtime_error for unknown names and std::invalid_argument if
// the model is not currently loaded. When `update_last_used` is set, the
// model's LRU timestamp is refreshed (used for POST/inference requests).
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    // children always listen on loopback (CHILD_ADDR), regardless of the router host
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop,
        base_params.timeout_read,
        base_params.timeout_write
    );
    return proxy;
}
713
714std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
715 // send a notification to the router server that a model instance is ready
716 common_log_pause(common_log_main());
717 fflush(stdout);
718 fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
719 fflush(stdout);
720 common_log_resume(common_log_main());
721
722 // setup thread for monitoring stdin
723 return std::thread([shutdown_handler]() {
724 // wait for EOF on stdin
725 SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
726 bool eof = false;
727 while (true) {
728 std::string line;
729 if (!std::getline(std::cin, line)) {
730 // EOF detected, that means the router server is unexpectedly exit or killed
731 eof = true;
732 break;
733 }
734 if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
735 SRV_INF("%s", "exit command received, exiting...\n");
736 shutdown_handler(0);
737 break;
738 }
739 }
740 if (eof) {
741 SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
742 exit(1);
743 }
744 });
745}
746
747
748
749//
750// server_models_routes
751//
752
753static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
754 res->status = 200;
755 res->data = safe_json_to_str(response_data);
756}
757
758static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
759 res->status = json_value(error_data, "code", 500);
760 res->data = safe_json_to_str({{ "error", error_data }});
761}
762
763static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
764 if (name.empty()) {
765 res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
766 return false;
767 }
768 auto meta = models.get_meta(name);
769 if (!meta.has_value()) {
770 res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST));
771 return false;
772 }
773 if (models_autoload) {
774 models.ensure_model_loaded(name);
775 } else {
776 if (meta->status != SERVER_MODEL_STATUS_LOADED) {
777 res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
778 return false;
779 }
780 }
781 return true;
782}
783
784static bool is_autoload(const common_params & params, const server_http_req & req) {
785 std::string autoload = req.get_param("autoload");
786 if (autoload.empty()) {
787 return params.models_autoload;
788 } else {
789 return autoload == "true" || autoload == "1";
790 }
791}
792
// Wire up all router HTTP handlers: props, GET/POST proxying to child
// instances, model listing (OAI-compatible) and explicit load/unload.
void server_models_routes::init_routes() {
    // /props: answered by the router itself when no model is selected,
    // otherwise proxied to the target child instance
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
                {"webui_settings", webui_settings},
            });
            return res;
        }
        return proxy_get(req);
    };

    // GET proxy: target model comes from the "model" query parameter;
    // does NOT refresh the LRU timestamp
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };

    // POST proxy: target model comes from the JSON body ("model" field)
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };

    // explicitly start a model instance (no-op error if already loaded)
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };

    // list all known models in an OAI-compatible format, with per-model
    // status, rendered args and (for customized models) the effective preset
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (!meta.preset.name.empty()) {
                // hide router-managed options from the reported preset
                common_preset preset_copy = meta.preset;
                unset_reserved_args(preset_copy, false);
                preset_copy.unset_option("LLAMA_ARG_HOST");
                preset_copy.unset_option("LLAMA_ARG_PORT");
                preset_copy.unset_option("LLAMA_ARG_ALIAS");
                status["preset"] = preset_copy.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };

    // request an asynchronous unload of a running model instance
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (!model->is_active()) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
914
915
916
917//
918// server_http_proxy
919//
920
921// simple implementation of a pipe
922// used for streaming data between threads
// Single-producer/single-consumer message pipe between the HTTP client
// thread (writer) and the response streaming callback (reader).
// The reader polls with a bounded wait so that cancellation (should_stop)
// is observed promptly; the poll also acts as a safety net since
// close_write() flips its atomic flag without holding the mutex, so a
// notification could in principle race with the reader's check.
template<typename T>
struct pipe_t {
    std::mutex mutex;                     // guards `queue`
    std::condition_variable cv;           // signaled on new data / close
    std::queue<T> queue;                  // buffered messages, FIFO
    std::atomic<bool> writer_closed{false}; // set once the writer is done (EOF)
    std::atomic<bool> reader_closed{false}; // set once the reader gave up (broken pipe)
    // signal EOF to the reader; further read() calls drain the queue then return false
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // signal broken pipe to the writer; further write() calls return false
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // blocking read; returns false on EOF or when should_stop() reports cancellation
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            // drain pending data before honoring EOF, so no message is lost
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }
    // non-blocking write; returns false if the reader has gone away
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
967
// Return an ASCII-lowercased copy of `value` (used to normalize header names).
static std::string to_lower_copy(const std::string & value) {
    std::string lowered;
    lowered.reserve(value.size());
    for (const unsigned char c : value) {
        lowered.push_back(static_cast<char>(std::tolower(c)));
    }
    return lowered;
}
973
// Decide whether a (lowercased) response header from a child instance must
// be dropped before forwarding: the router already emits these itself, so
// passing them through would duplicate them.
// content-length: quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
static bool should_strip_proxy_header(const std::string & header_name) {
    static const char * exact_matches[] = {
        "server",
        "transfer-encoding",
        "content-length",
        "keep-alive",
    };
    for (const char * blocked : exact_matches) {
        if (header_name == blocked) {
            return true;
        }
    }
    // Router injects CORS headers itself; child copies would be duplicates
    return header_name.rfind("access-control-", 0) == 0;
}
990
// Stream an HTTP request to a child instance and expose the response back
// to the router's HTTP layer. A detached worker thread runs the client
// request and pushes headers + body chunks through a shared pipe; the
// constructor blocks only until the response headers arrive. All state
// shared with the worker lives in shared_ptrs, so destroying this object
// while the worker is still running is safe.
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop,
        int32_t timeout_read,
        int32_t timeout_write
    ) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();

    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server)
    cli->set_read_timeout(timeout_write, 0);
    this->status = 500; // to be overwritten upon response
    // called when the router-side response is torn down: closing both ends
    // unblocks the worker thread and makes it exit
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };

    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };

    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            if (lowered == "content-type") {
                // content-type is carried separately in msg_t
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };

    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }

    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // surface the client error to the router side as a 500 response
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detached: lifetime is controlled by the shared pipe/cli, not by this object
    this->thread.detach();

    // wait for the first chunk (headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}