diff options
Diffstat (limited to 'llama.cpp/tools/server/server-models.cpp')
| -rw-r--r-- | llama.cpp/tools/server/server-models.cpp | 1092 |
1 files changed, 1092 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/server-models.cpp b/llama.cpp/tools/server/server-models.cpp new file mode 100644 index 0000000..5765547 --- /dev/null +++ b/llama.cpp/tools/server/server-models.cpp @@ -0,0 +1,1092 @@ +#include "server-common.h" +#include "server-models.h" + +#include "preset.h" +#include "download.h" + +#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h +#include <sheredom/subprocess.h> + +#include <functional> +#include <algorithm> +#include <thread> +#include <mutex> +#include <condition_variable> +#include <cstring> +#include <atomic> +#include <chrono> +#include <queue> +#include <filesystem> +#include <cstring> + +#ifdef _WIN32 +#include <winsock2.h> +#include <windows.h> +#else +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <unistd.h> +extern char **environ; +#endif + +#if defined(__APPLE__) && defined(__MACH__) +// macOS: use _NSGetExecutablePath to get the executable path +#include <mach-o/dyld.h> +#include <limits.h> +#endif + +#define DEFAULT_STOP_TIMEOUT 10 // seconds + +#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" +#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" + +// address for child process, this is needed because router may run on 0.0.0.0 +// ref: https://github.com/ggml-org/llama.cpp/issues/17862 +#define CHILD_ADDR "127.0.0.1" + +static std::filesystem::path get_server_exec_path() { +#if defined(_WIN32) + wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths + DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); + if (len == 0 || len >= _countof(buf)) { + throw std::runtime_error("GetModuleFileNameW failed or path too long"); + } + return std::filesystem::path(buf); +#elif defined(__APPLE__) && defined(__MACH__) + char small_path[PATH_MAX]; + uint32_t size = sizeof(small_path); + + if (_NSGetExecutablePath(small_path, &size) == 0) { + // resolve any symlinks to get absolute path + try { + return std::filesystem::canonical(std::filesystem::path(small_path)); + } catch (...) { + return std::filesystem::path(small_path); + } + } else { + // buffer was too small, allocate required size and call again + std::vector<char> buf(size); + if (_NSGetExecutablePath(buf.data(), &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(buf.data())); + } catch (...) { + return std::filesystem::path(buf.data()); + } + } + throw std::runtime_error("_NSGetExecutablePath failed after buffer resize"); + } +#else + char path[FILENAME_MAX]; + ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); + if (count <= 0) { + throw std::runtime_error("failed to resolve /proc/self/exe"); + } + return std::filesystem::path(std::string(path, count)); +#endif +} + +static void unset_reserved_args(common_preset & preset, bool unset_model_args) { + preset.unset_option("LLAMA_ARG_SSL_KEY_FILE"); + preset.unset_option("LLAMA_ARG_SSL_CERT_FILE"); + preset.unset_option("LLAMA_API_KEY"); + preset.unset_option("LLAMA_ARG_MODELS_DIR"); + preset.unset_option("LLAMA_ARG_MODELS_MAX"); + preset.unset_option("LLAMA_ARG_MODELS_PRESET"); + preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD"); + if (unset_model_args) { + preset.unset_option("LLAMA_ARG_MODEL"); + preset.unset_option("LLAMA_ARG_MMPROJ"); + preset.unset_option("LLAMA_ARG_HF_REPO"); + } +} + +#ifdef _WIN32 +static std::string wide_to_utf8(const wchar_t * ws) { + if (!ws || !*ws) { + return {}; + } + + const int len = static_cast<int>(std::wcslen(ws)); + const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr); + if (bytes == 0) { + return {}; + } + + std::string utf8(bytes, '\0'); + WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr); + + return utf8; +} +#endif + +static std::vector<std::string> get_environment() { + std::vector<std::string> env; + +#ifdef _WIN32 + LPWCH env_block = GetEnvironmentStringsW(); + if (!env_block) { + return env; + } + for (LPWCH e = env_block; *e; e += wcslen(e) + 1) { + env.emplace_back(wide_to_utf8(e)); + } + FreeEnvironmentStringsW(env_block); +#else + if (environ == nullptr) { + return env; + } + for (char ** e = environ; *e != nullptr; e++) { + env.emplace_back(*e); + } +#endif + + return env; +} + +void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) { + // update params + unset_reserved_args(preset, false); + preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR); + preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port)); + preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name); + // TODO: maybe validate preset before rendering ? + // render args + args = preset.to_args(bin_path); +} + +// +// server_models +// + +server_models::server_models( + const common_params & params, + int argc, + char ** argv) + : ctx_preset(LLAMA_EXAMPLE_SERVER), + base_params(params), + base_env(get_environment()), + base_preset(ctx_preset.load_from_args(argc, argv)) { + // clean up base preset + unset_reserved_args(base_preset, true); + // set binary path + try { + bin_path = get_server_exec_path().string(); + } catch (const std::exception & e) { + bin_path = argv[0]; + LOG_WRN("failed to get server executable path: %s\n", e.what()); + LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]); + } + load_models(); +} + +void server_models::add_model(server_model_meta && meta) { + if (mapping.find(meta.name) != mapping.end()) { + throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); + } + meta.update_args(ctx_preset, bin_path); // render args + std::string name = meta.name; + mapping[name] = instance_t{ + /* subproc */ std::make_shared<subprocess_s>(), + /* th */ std::thread(), + /* meta */ std::move(meta) + }; +} + +// TODO: allow refreshing cached model list +void server_models::load_models() { + // loading models from 3 sources: + // 1. cached models + common_presets cached_models = ctx_preset.load_from_cache(); + SRV_INF("Loaded %zu cached model presets\n", cached_models.size()); + // 2. local models from --models-dir + common_presets local_models; + if (!base_params.models_dir.empty()) { + local_models = ctx_preset.load_from_models_dir(base_params.models_dir); + SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str()); + } + // 3. custom-path models from presets + common_preset global = {}; + common_presets custom_presets = {}; + if (!base_params.models_preset.empty()) { + custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global); + SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str()); + } + + // cascade, apply global preset first + cached_models = ctx_preset.cascade(global, cached_models); + local_models = ctx_preset.cascade(global, local_models); + custom_presets = ctx_preset.cascade(global, custom_presets); + + // note: if a model exists in both cached and local, local takes precedence + common_presets final_presets; + for (const auto & [name, preset] : cached_models) { + final_presets[name] = preset; + } + for (const auto & [name, preset] : local_models) { + final_presets[name] = preset; + } + + // process custom presets from INI + for (const auto & [name, custom] : custom_presets) { + if (final_presets.find(name) != final_presets.end()) { + // apply custom config if exists + common_preset & target = final_presets[name]; + target.merge(custom); + } else { + // otherwise add directly + final_presets[name] = custom; + } + } + + // server base preset from CLI args take highest precedence + for (auto & [name, preset] : final_presets) { + preset.merge(base_preset); + } + + // convert presets to server_model_meta and add to mapping + for (const auto & preset : final_presets) { + server_model_meta meta{ + /* preset */ preset.second, + /* name */ preset.first, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector<std::string>(), + /* exit_code */ 0, + /* stop_timeout */ DEFAULT_STOP_TIMEOUT, + }; + add_model(std::move(meta)); + } + + // log available models + { + std::unordered_set<std::string> custom_names; + for (const auto & [name, preset] : custom_presets) { + custom_names.insert(name); + } + SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); + for (const auto & [name, inst] : mapping) { + bool has_custom = custom_names.find(name) != custom_names.end(); + SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str()); + } + } + + // handle custom stop-timeout option + for (auto & [name, inst] : mapping) { + std::string val; + if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) { + try { + inst.meta.stop_timeout = std::stoi(val); + } catch (...) { + SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n", + val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT); + inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT; + } + } + } + + // load any autoload models + std::vector<std::string> models_to_load; + for (const auto & [name, inst] : mapping) { + std::string val; + if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) { + models_to_load.push_back(name); + } + } + if ((int)models_to_load.size() > base_params.models_max) { + throw std::runtime_error(string_format( + "number of models to load on startup (%zu) exceeds models_max (%d)", + models_to_load.size(), + base_params.models_max + )); + } + for (const auto & name : models_to_load) { + SRV_INF("(startup) loading model %s\n", name.c_str()); + load(name); + } +} + +void server_models::update_meta(const std::string & name, const server_model_meta & meta) { + std::lock_guard<std::mutex> lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + it->second.meta = meta; + } + cv.notify_all(); // notify wait_until_loaded +} + +bool server_models::has_model(const std::string & name) { + std::lock_guard<std::mutex> lk(mutex); + return mapping.find(name) != mapping.end(); +} + +std::optional<server_model_meta> server_models::get_meta(const std::string & name) { + std::lock_guard<std::mutex> lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta; + } + return std::nullopt; +} + +static int get_free_port() { +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + return -1; + } + typedef SOCKET native_socket_t; +#define INVALID_SOCKET_VAL INVALID_SOCKET +#define CLOSE_SOCKET(s) closesocket(s) +#else + typedef int native_socket_t; +#define INVALID_SOCKET_VAL -1 +#define CLOSE_SOCKET(s) close(s) +#endif + + native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET_VAL) { +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + struct sockaddr_in serv_addr; + std::memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(0); + + if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + +#ifdef _WIN32 + int namelen = sizeof(serv_addr); +#else + socklen_t namelen = sizeof(serv_addr); +#endif + if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + int port = ntohs(serv_addr.sin_port); + + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + + return port; +} + +// helper to convert vector<string> to char ** +// pointers are only valid as long as the original vector is valid +static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) { + std::vector<char *> result; + result.reserve(vec.size() + 1); + for (const auto & s : vec) { + result.push_back(const_cast<char*>(s.c_str())); + } + result.push_back(nullptr); + return result; +} + +std::vector<server_model_meta> server_models::get_all_meta() { + std::lock_guard<std::mutex> lk(mutex); + std::vector<server_model_meta> result; + result.reserve(mapping.size()); + for (const auto & [name, inst] : mapping) { + result.push_back(inst.meta); + } + return result; +} + +void server_models::unload_lru() { + if (base_params.models_max <= 0) { + return; // no limit + } + // remove one of the servers if we passed the models_max (least recently used - LRU) + std::string lru_model_name = ""; + int64_t lru_last_used = ggml_time_ms(); + size_t count_active = 0; + { + std::unique_lock<std::mutex> lk(mutex); + for (const auto & m : mapping) { + if (m.second.meta.is_active()) { + count_active++; + if (m.second.meta.last_used < lru_last_used) { + lru_model_name = m.first; + lru_last_used = m.second.meta.last_used; + } + } + } + } + if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { + SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); + unload(lru_model_name); + // wait for unload to complete + { + std::unique_lock<std::mutex> lk(mutex); + cv.wait(lk, [this, &lru_model_name]() { + return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED; + }); + } + } +} + +void server_models::load(const std::string & name) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); + + std::lock_guard<std::mutex> lk(mutex); + + auto meta = mapping[name].meta; + if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model %s is not ready\n", name.c_str()); + return; + } + + // prepare new instance info + instance_t inst; + inst.meta = meta; + inst.meta.port = get_free_port(); + inst.meta.status = SERVER_MODEL_STATUS_LOADING; + inst.meta.last_used = ggml_time_ms(); + + if (inst.meta.port <= 0) { + throw std::runtime_error("failed to get a port number"); + } + + inst.subproc = std::make_shared<subprocess_s>(); + { + SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); + + inst.meta.update_args(ctx_preset, bin_path); // render args + + std::vector<std::string> child_args = inst.meta.args; // copy + std::vector<std::string> child_env = base_env; // copy + child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + + SRV_INF("%s", "spawning server instance with args:\n"); + for (const auto & arg : child_args) { + SRV_INF(" %s\n", arg.c_str()); + } + inst.meta.args = child_args; // save for debugging + + std::vector<char *> argv = to_char_ptr_array(child_args); + std::vector<char *> envp = to_char_ptr_array(child_env); + + // TODO @ngxson : maybe separate stdout and stderr in the future + // so that we can use stdout for commands and stderr for logging + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get()); + if (result != 0) { + throw std::runtime_error("failed to spawn server instance"); + } + + inst.stdin_file = subprocess_stdin(inst.subproc.get()); + } + + // start a thread to manage the child process + // captured variables are guaranteed to be destroyed only after the thread is joined + inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { + FILE * stdin_file = subprocess_stdin(child_proc.get()); + FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr + + std::thread log_thread([&]() { + // read stdout/stderr and forward to main server log + // also handle status report from child process + bool state_received = false; // true if child state received + if (stdout_file) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { + LOG("[%5d] %s", port, buffer); + if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { + // child process is ready + this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); + state_received = true; + } + } + } else { + SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); + } + }); + + std::thread stopping_thread([&]() { + // thread to monitor stopping signal + auto is_stopping = [this, &name]() { + return this->stopping_models.find(name) != this->stopping_models.end(); + }; + { + std::unique_lock<std::mutex> lk(this->mutex); + this->cv_stop.wait(lk, is_stopping); + } + SRV_INF("stopping model instance name=%s\n", name.c_str()); + // send interrupt to child process + fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); + fflush(stdin_file); + // wait to stop gracefully or timeout + int64_t start_time = ggml_time_ms(); + while (true) { + std::unique_lock<std::mutex> lk(this->mutex); + if (!is_stopping()) { + return; // already stopped + } + int64_t elapsed = ggml_time_ms() - start_time; + if (elapsed >= stop_timeout * 1000) { + // timeout, force kill + SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout); + subprocess_terminate(child_proc.get()); + return; + } + this->cv_stop.wait_for(lk, std::chrono::seconds(1)); + } + }); + + // we reach here when the child process exits + // note: we cannot join() prior to this point because it will close stdin_file + if (log_thread.joinable()) { + log_thread.join(); + } + + // stop the timeout monitoring thread + { + std::lock_guard<std::mutex> lk(this->mutex); + stopping_models.erase(name); + cv_stop.notify_all(); + } + if (stopping_thread.joinable()) { + stopping_thread.join(); + } + + // get the exit code + int exit_code = 0; + subprocess_join(child_proc.get(), &exit_code); + subprocess_destroy(child_proc.get()); + + // update status and exit code + this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); + SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); + }); + + // clean up old process/thread if exists + { + auto & old_instance = mapping[name]; + // old process should have exited already, but just in case, we clean it up here + if (subprocess_alive(old_instance.subproc.get())) { + SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); + subprocess_terminate(old_instance.subproc.get()); // force kill + } + if (old_instance.th.joinable()) { + old_instance.th.join(); + } + } + + mapping[name] = std::move(inst); + cv.notify_all(); +} + +void server_models::unload(const std::string & name) { + std::lock_guard<std::mutex> lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + if (it->second.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + stopping_models.insert(name); + cv_stop.notify_all(); + // status change will be handled by the managing thread + } else { + SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); + } + } +} + +void server_models::unload_all() { + std::vector<std::thread> to_join; + { + std::lock_guard<std::mutex> lk(mutex); + for (auto & [name, inst] : mapping) { + if (inst.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + stopping_models.insert(name); + cv_stop.notify_all(); + // status change will be handled by the managing thread + } + // moving the thread to join list to avoid deadlock + to_join.push_back(std::move(inst.th)); + } + } + for (auto & th : to_join) { + if (th.joinable()) { + th.join(); + } + } +} + +void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { + std::unique_lock<std::mutex> lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + auto & meta = it->second.meta; + meta.status = status; + meta.exit_code = exit_code; + } + cv.notify_all(); +} + +void server_models::wait_until_loaded(const std::string & name) { + std::unique_lock<std::mutex> lk(mutex); + cv.wait(lk, [this, &name]() { + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; + } + return false; + }); +} + +bool server_models::ensure_model_loaded(const std::string & name) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status == SERVER_MODEL_STATUS_LOADED) { + return false; // already loaded + } + if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); + load(name); + } + + // for loading state + SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); + wait_until_loaded(name); + + // check final status + meta = get_meta(name); + if (!meta.has_value() || meta->is_failed()) { + throw std::runtime_error("model name=" + name + " failed to load"); + } + + return true; +} + +server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status != SERVER_MODEL_STATUS_LOADED) { + throw std::invalid_argument("model name=" + name + " is not loaded"); + } + if (update_last_used) { + std::unique_lock<std::mutex> lk(mutex); + mapping[name].meta.last_used = ggml_time_ms(); + } + SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); + auto proxy = std::make_unique<server_http_proxy>( + method, + CHILD_ADDR, + meta->port, + req.path, + req.headers, + req.body, + req.should_stop, + base_params.timeout_read, + base_params.timeout_write + ); + return proxy; +} + +std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) { + // send a notification to the router server that a model instance is ready + common_log_pause(common_log_main()); + fflush(stdout); + fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY); + fflush(stdout); + common_log_resume(common_log_main()); + + // setup thread for monitoring stdin + return std::thread([shutdown_handler]() { + // wait for EOF on stdin + SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n"); + bool eof = false; + while (true) { + std::string line; + if (!std::getline(std::cin, line)) { + // EOF detected, that means the router server is unexpectedly exit or killed + eof = true; + break; + } + if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) { + SRV_INF("%s", "exit command received, exiting...\n"); + shutdown_handler(0); + break; + } + } + if (eof) { + SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n"); + exit(1); + } + }); +} + + + +// +// server_models_routes +// + +static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) { + res->status = 200; + res->data = safe_json_to_str(response_data); +} + +static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) { + res->status = json_value(error_data, "code", 500); + res->data = safe_json_to_str({{ "error", error_data }}); +} + +static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) { + if (name.empty()) { + res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + auto meta = models.get_meta(name); + if (!meta.has_value()) { + res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST)); + return false; + } + if (models_autoload) { + models.ensure_model_loaded(name); + } else { + if (meta->status != SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + } + return true; +} + +static bool is_autoload(const common_params & params, const server_http_req & req) { + std::string autoload = req.get_param("autoload"); + if (autoload.empty()) { + return params.models_autoload; + } else { + return autoload == "true" || autoload == "1"; + } +} + +void server_models_routes::init_routes() { + this->get_router_props = [this](const server_http_req & req) { + std::string name = req.get_param("model"); + if (name.empty()) { + // main instance + auto res = std::make_unique<server_http_res>(); + res_ok(res, { + // TODO: add support for this on web UI + {"role", "router"}, + {"max_instances", 4}, // dummy value for testing + // this is a dummy response to make sure webui doesn't break + {"model_alias", "llama-server"}, + {"model_path", "none"}, + {"default_generation_settings", { + {"params", json{}}, + {"n_ctx", 0}, + }}, + {"webui_settings", webui_settings}, + }); + return res; + } + return proxy_get(req); + }; + + this->proxy_get = [this](const server_http_req & req) { + std::string method = "GET"; + std::string name = req.get_param("model"); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique<server_http_res>(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, false); + }; + + this->proxy_post = [this](const server_http_req & req) { + std::string method = "POST"; + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique<server_http_res>(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, true); // update last usage for POST request only + }; + + this->post_router_models_load = [this](const server_http_req & req) { + auto res = std::make_unique<server_http_res>(); + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + auto model = models.get_meta(name); + if (!model.has_value()) { + res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); + return res; + } + if (model->status == SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + models.load(name); + res_ok(res, {{"success", true}}); + return res; + }; + + this->get_router_models = [this](const server_http_req &) { + auto res = std::make_unique<server_http_res>(); + json models_json = json::array(); + auto all_models = models.get_all_meta(); + std::time_t t = std::time(0); + for (const auto & meta : all_models) { + json status { + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, + }; + if (!meta.preset.name.empty()) { + common_preset preset_copy = meta.preset; + unset_reserved_args(preset_copy, false); + preset_copy.unset_option("LLAMA_ARG_HOST"); + preset_copy.unset_option("LLAMA_ARG_PORT"); + preset_copy.unset_option("LLAMA_ARG_ALIAS"); + status["preset"] = preset_copy.to_ini(); + } + if (meta.is_failed()) { + status["exit_code"] = meta.exit_code; + status["failed"] = true; + } + models_json.push_back(json { + {"id", meta.name}, + {"object", "model"}, // for OAI-compat + {"owned_by", "llamacpp"}, // for OAI-compat + {"created", t}, // for OAI-compat + {"status", status}, + // TODO: add other fields, may require reading GGUF metadata + }); + } + res_ok(res, { + {"data", models_json}, + {"object", "list"}, + }); + return res; + }; + + this->post_router_models_unload = [this](const server_http_req & req) { + auto res = std::make_unique<server_http_res>(); + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + auto model = models.get_meta(name); + if (!model.has_value()) { + res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + if (!model->is_active()) { + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + models.unload(name); + res_ok(res, {{"success", true}}); + return res; + }; +} + + + +// +// server_http_proxy +// + +// simple implementation of a pipe +// used for streaming data between threads +template<typename T> +struct pipe_t { + std::mutex mutex; + std::condition_variable cv; + std::queue<T> queue; + std::atomic<bool> writer_closed{false}; + std::atomic<bool> reader_closed{false}; + void close_write() { + writer_closed.store(true, std::memory_order_relaxed); + cv.notify_all(); + } + void close_read() { + reader_closed.store(true, std::memory_order_relaxed); + cv.notify_all(); + } + bool read(T & output, const std::function<bool()> & should_stop) { + std::unique_lock<std::mutex> lk(mutex); + constexpr auto poll_interval = std::chrono::milliseconds(500); + while (true) { + if (!queue.empty()) { + output = std::move(queue.front()); + queue.pop(); + return true; + } + if (writer_closed.load()) { + return false; // clean EOF + } + if (should_stop()) { + close_read(); // signal broken pipe to writer + return false; // cancelled / reader no longer alive + } + cv.wait_for(lk, poll_interval); + } + } + bool write(T && data) { + std::lock_guard<std::mutex> lk(mutex); + if (reader_closed.load()) { + return false; // broken pipe + } + queue.push(std::move(data)); + cv.notify_one(); + return true; + } +}; + +static std::string to_lower_copy(const std::string & value) { + std::string lowered(value.size(), '\0'); + std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); }); + return lowered; +} + +static bool should_strip_proxy_header(const std::string & header_name) { + // Headers that get duplicated when router forwards child responses + if (header_name == "server" || + header_name == "transfer-encoding" || + header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710 + header_name == "keep-alive") { + return true; + } + + // Router injects CORS, child also sends them: duplicate + if (header_name.rfind("access-control-", 0) == 0) { + return true; + } + + return false; +} + +server_http_proxy::server_http_proxy( + const std::string & method, + const std::string & host, + int port, + const std::string & path, + const std::map<std::string, std::string> & headers, + const std::string & body, + const std::function<bool()> should_stop, + int32_t timeout_read, + int32_t timeout_write + ) { + // shared between reader and writer threads + auto cli = std::make_shared<httplib::Client>(host, port); + auto pipe = std::make_shared<pipe_t<msg_t>>(); + + // setup Client + cli->set_connection_timeout(0, 200000); // 200 milliseconds + cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server) + cli->set_read_timeout(timeout_write, 0); + this->status = 500; // to be overwritten upon response + this->cleanup = [pipe]() { + pipe->close_read(); + pipe->close_write(); + }; + + // wire up the receive end of the pipe + this->next = [pipe, should_stop](std::string & out) -> bool { + msg_t msg; + bool has_next = pipe->read(msg, should_stop); + if (!msg.data.empty()) { + out = std::move(msg.data); + } + return has_next; // false if EOF or pipe broken + }; + + // wire up the HTTP client + // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends + httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) { + msg_t msg; + msg.status = response.status; + for (const auto & [key, value] : response.headers) { + const auto lowered = to_lower_copy(key); + if (should_strip_proxy_header(lowered)) { + continue; + } + if (lowered == "content-type") { + msg.content_type = value; + continue; + } + msg.headers[key] = value; + } + return pipe->write(std::move(msg)); // send headers first + }; + httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { + // send data chunks + // returns false if pipe is closed / broken (signal to stop receiving) + return pipe->write({{}, 0, std::string(data, data_length), ""}); + }; + + // prepare the request to destination server + httplib::Request req; + { + req.method = method; + req.path = path; + for (const auto & [key, value] : headers) { + req.set_header(key, value); + } + req.body = body; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + } + + // start the proxy thread + SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str()); + this->thread = std::thread([cli, pipe, req]() { + auto result = cli->send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("http client error: %s\n", err_str.c_str()); + pipe->write({{}, 500, "", ""}); // header + pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body + } + pipe->close_write(); // signal EOF to reader + SRV_DBG("%s", "client request thread ended\n"); + }); + this->thread.detach(); + + // wait for the first chunk (headers) + { + msg_t header; + if (pipe->read(header, should_stop)) { + SRV_DBG("%s", "received response headers\n"); + this->status = header.status; + this->headers = std::move(header.headers); + if (!header.content_type.empty()) { + this->content_type = std::move(header.content_type); + } + } else { + SRV_DBG("%s", "no response headers received (request cancelled?)\n"); + } + } +} |
