Diffstat (limited to 'llama.cpp/tools/server/server-models.h')
| -rw-r--r-- | llama.cpp/tools/server/server-models.h | 203 |
1 file changed, 203 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/server-models.h b/llama.cpp/tools/server/server-models.h
new file mode 100644
index 0000000..a397abd
--- /dev/null
+++ b/llama.cpp/tools/server/server-models.h
@@ -0,0 +1,203 @@
+#pragma once
+
+#include "common.h"
+#include "preset.h"
+#include "server-common.h"
+#include "server-http.h"
+
+#include <mutex>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <set>
+
+/**
+ * state diagram:
+ *
+ * UNLOADED ──► LOADING ──► LOADED
+ *    ▲           │            │
+ *    └──failed───┘            │
+ *    ▲                        │
+ *    └────────unloaded────────┘
+ */
+enum server_model_status {
+    // TODO: also add downloading state when the logic is added
+    SERVER_MODEL_STATUS_UNLOADED,
+    SERVER_MODEL_STATUS_LOADING,
+    SERVER_MODEL_STATUS_LOADED
+};
+
+static server_model_status server_model_status_from_string(const std::string & status_str) {
+    if (status_str == "unloaded") {
+        return SERVER_MODEL_STATUS_UNLOADED;
+    }
+    if (status_str == "loading") {
+        return SERVER_MODEL_STATUS_LOADING;
+    }
+    if (status_str == "loaded") {
+        return SERVER_MODEL_STATUS_LOADED;
+    }
+    throw std::runtime_error("invalid server model status");
+}
+
+static std::string server_model_status_to_string(server_model_status status) {
+    switch (status) {
+        case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
+        case SERVER_MODEL_STATUS_LOADING:  return "loading";
+        case SERVER_MODEL_STATUS_LOADED:   return "loaded";
+        default: return "unknown";
+    }
+}
+
+struct server_model_meta {
+    common_preset preset;
+    std::string name;
+    int port = 0;
+    server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
+    int64_t last_used = 0; // for LRU unloading
+    std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
+    int exit_code = 0; // exit code of the model instance process (only meaningful after the instance exited; see is_failed())
+    int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
+
+    bool is_active() const {
+        return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
+    }
+
+    bool is_failed() const {
+        return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
+    }
+
+    void update_args(common_preset_context & ctx_presets, std::string bin_path);
+};
+
+struct subprocess_s;
+
+struct server_models {
+private:
+    struct instance_t {
+        std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
+        std::thread th;
+        server_model_meta meta;
+        FILE * stdin_file = nullptr;
+    };
+
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::map<std::string, instance_t> mapping;
+
+    // for stopping models
+    std::condition_variable cv_stop;
+    std::set<std::string> stopping_models;
+
+    common_preset_context ctx_preset;
+
+    common_params base_params;
+    std::string bin_path;
+    std::vector<std::string> base_env;
+    common_preset base_preset; // base preset from llama-server CLI args
+
+    void update_meta(const std::string & name, const server_model_meta & meta);
+
+    // unload least recently used models if the limit is reached
+    void unload_lru();
+
+    // not thread-safe, caller must hold mutex
+    void add_model(server_model_meta && meta);
+
+public:
+    server_models(const common_params & params, int argc, char ** argv);
+
+    void load_models();
+
+    // check if a model instance exists (thread-safe)
+    bool has_model(const std::string & name);
+
+    // return a copy of model metadata (thread-safe)
+    std::optional<server_model_meta> get_meta(const std::string & name);
+
+    // return a copy of all model metadata (thread-safe)
+    std::vector<server_model_meta> get_all_meta();
+
+    // load and unload model instances
+    // these functions are thread-safe
+    void load(const std::string & name);
+    void unload(const std::string & name);
+    void unload_all();
+
+    // update the status of a model instance (thread-safe)
+    void update_status(const std::string & name, server_model_status status, int exit_code);
+
+    // wait until the model instance is fully loaded (thread-safe)
+    // return when the model is loaded or failed to load
+    void wait_until_loaded(const std::string & name);
+
+    // load the model if not loaded, otherwise do nothing (thread-safe)
+    // return false if model is already loaded; return true otherwise (meta may need to be refreshed)
+    bool ensure_model_loaded(const std::string & name);
+
+    // proxy an HTTP request to the model instance
+    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
+
+    // notify the router server that a model instance is ready
+    // return the monitoring thread (to be joined by the caller)
+    static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
+};
+
+struct server_models_routes {
+    common_params params;
+    json webui_settings = json::object();
+    server_models models;
+    server_models_routes(const common_params & params, int argc, char ** argv)
+            : params(params), models(params, argc, argv) {
+        if (!this->params.webui_config_json.empty()) {
+            try {
+                webui_settings = json::parse(this->params.webui_config_json);
+            } catch (const std::exception & e) {
+                LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
+                throw;
+            }
+        }
+        init_routes();
+    }
+
+    void init_routes();
+    // handlers are stored as lambdas so that they can capture `this` without `std::bind`
+    server_http_context::handler_t get_router_props;
+    server_http_context::handler_t proxy_get;
+    server_http_context::handler_t proxy_post;
+    server_http_context::handler_t get_router_models;
+    server_http_context::handler_t post_router_models_load;
+    server_http_context::handler_t post_router_models_unload;
+};
+
+/**
+ * A simple HTTP proxy that forwards requests to another server
+ * and relays the responses back.
+ */
+struct server_http_proxy : server_http_res {
+    std::function<void()> cleanup = nullptr;
+public:
+    server_http_proxy(const std::string & method,
+                      const std::string & host,
+                      int port,
+                      const std::string & path,
+                      const std::map<std::string, std::string> & headers,
+                      const std::string & body,
+                      const std::function<bool()> should_stop,
+                      int32_t timeout_read,
+                      int32_t timeout_write);
+    ~server_http_proxy() {
+        if (cleanup) {
+            cleanup();
+        }
+    }
+private:
+    std::thread thread;
+    struct msg_t {
+        std::map<std::string, std::string> headers;
+        int status = 0;
+        std::string data;
+        std::string content_type;
+    };
+};
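
The `server_models` interface above implies a load-on-demand flow on the router side: check that the model exists, ensure it is loaded, wait for the state machine to settle, then forward the request. A minimal sketch of that flow, using only the declarations in this header; the function name `route_completion` and the error handling are hypothetical, not part of the patch:

#include "server-models.h"

// Sketch: resolve a model by name, load it on demand, then forward the
// request to the child instance. Everything called here is declared above.
static server_http_res_ptr route_completion(server_models & models,
                                            const server_http_req & req,
                                            const std::string & name) {
    if (!models.has_model(name)) {
        return nullptr; // unknown model; real code would answer 404 here
    }
    // ensure_model_loaded() returns false when the instance was already
    // loaded; when it returns true, the metadata may have changed, so wait
    // for the LOADING -> LOADED (or failed) transition and re-read it.
    if (models.ensure_model_loaded(name)) {
        models.wait_until_loaded(name);
    }
    auto meta = models.get_meta(name);
    if (!meta || meta->is_failed()) {
        return nullptr; // child exited with a non-zero exit_code
    }
    // forward the request; update_last_used=true feeds the LRU unloader
    return models.proxy_request(req, "POST", name, /*update_last_used=*/true);
}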

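The state diagram also fixes which transitions are legal. Note that there is no FAILED enum value: a failed load ends back in UNLOADED with a non-zero `exit_code`, which is exactly what `is_failed()` tests. A hypothetical helper (not part of this header) that makes the diagram's edges explicit:

#include "server-models.h"

// Hypothetical validity check derived from the state diagram above:
// UNLOADED may only start LOADING; LOADING either finishes in LOADED or
// falls back to UNLOADED on failure; LOADED can only be unloaded.
static bool is_valid_transition(server_model_status from, server_model_status to) {
    switch (from) {
        case SERVER_MODEL_STATUS_UNLOADED:
            return to == SERVER_MODEL_STATUS_LOADING;
        case SERVER_MODEL_STATUS_LOADING:
            return to == SERVER_MODEL_STATUS_LOADED ||
                   to == SERVER_MODEL_STATUS_UNLOADED; // load failed
        case SERVER_MODEL_STATUS_LOADED:
            return to == SERVER_MODEL_STATUS_UNLOADED; // unloaded
    }
    return false;
}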