Diffstat (limited to 'llama.cpp/tools/server/server-models.h')
-rw-r--r--  llama.cpp/tools/server/server-models.h  203
1 file changed, 203 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/server-models.h b/llama.cpp/tools/server/server-models.h
new file mode 100644
index 0000000..a397abd
--- /dev/null
+++ b/llama.cpp/tools/server/server-models.h
@@ -0,0 +1,203 @@
+#pragma once
+
+#include "common.h"
+#include "preset.h"
+#include "server-common.h"
+#include "server-http.h"
+
+#include <mutex>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <set>
+
+/**
+ * state diagram:
+ *
+ * UNLOADED ──► LOADING ──► LOADED
+ *    ▲            │          │
+ *    └───failed───┘          │
+ *    ▲                       │
+ *    └───────unloaded────────┘
+ */
+enum server_model_status {
+    // TODO: also add downloading state when the logic is added
+    SERVER_MODEL_STATUS_UNLOADED,
+    SERVER_MODEL_STATUS_LOADING,
+    SERVER_MODEL_STATUS_LOADED
+};
+
+static server_model_status server_model_status_from_string(const std::string & status_str) {
+    if (status_str == "unloaded") {
+        return SERVER_MODEL_STATUS_UNLOADED;
+    }
+    if (status_str == "loading") {
+        return SERVER_MODEL_STATUS_LOADING;
+    }
+    if (status_str == "loaded") {
+        return SERVER_MODEL_STATUS_LOADED;
+    }
+    throw std::runtime_error("invalid server model status");
+}
+
+static std::string server_model_status_to_string(server_model_status status) {
+    switch (status) {
+        case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
+        case SERVER_MODEL_STATUS_LOADING:  return "loading";
+        case SERVER_MODEL_STATUS_LOADED:   return "loaded";
+        default: return "unknown";
+    }
+}
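+
+// Note: the enum has no separate "failed" value; per the diagram above, a
+// failed load drops back to UNLOADED and the failure is recorded through the
+// exit code kept on server_model_meta below. A minimal usage sketch of the
+// string helpers (illustrative only, not part of the original header):
+//
+//   server_model_status s = server_model_status_from_string("loading");
+//   // s == SERVER_MODEL_STATUS_LOADING
+//   // server_model_status_to_string(s) == "loading"
+//   // an unrecognized status string throws std::runtime_error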
+
+struct server_model_meta {
+    common_preset preset;
+    std::string name;
+    int port = 0;
+    server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
+    int64_t last_used = 0;         // for LRU unloading
+    std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
+    int exit_code = 0;             // exit code of the model instance process (only meaningful when a load has failed: status == UNLOADED && exit_code != 0, see is_failed())
+    int stop_timeout = 0;          // seconds to wait before force-killing the model instance during shutdown
+
+    bool is_active() const {
+        return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
+    }
+
+    bool is_failed() const {
+        return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
+    }
+
+    void update_args(common_preset_context & ctx_presets, std::string bin_path);
+};
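+
+// Illustrative sketch (not part of the original header): how a caller holding
+// a copy of the metadata might map it back onto the state diagram above.
+// `meta` is a hypothetical local copy, e.g. obtained from server_models::get_meta().
+//
+//   if (meta.is_active()) {
+//       // LOADING or LOADED: an instance process exists (or is being started) on meta.port
+//   } else if (meta.is_failed()) {
+//       // the load failed: back to UNLOADED, with meta.exit_code != 0
+//   } else {
+//       // cleanly UNLOADED: never started, or unloaded on request
+//   }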
+
+struct subprocess_s;
+
+struct server_models {
+private:
+    struct instance_t {
+        std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
+        std::thread th;
+        server_model_meta meta;
+        FILE * stdin_file = nullptr;
+    };
+
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::map<std::string, instance_t> mapping;
+
+    // for stopping models
+    std::condition_variable cv_stop;
+    std::set<std::string> stopping_models;
+
+    common_preset_context ctx_preset;
+
+    common_params base_params;
+    std::string bin_path;
+    std::vector<std::string> base_env;
+    common_preset base_preset; // base preset from llama-server CLI args
+
+    void update_meta(const std::string & name, const server_model_meta & meta);
+
+    // unload least recently used models if the limit is reached
+    void unload_lru();
+
+    // not thread-safe, caller must hold mutex
+    void add_model(server_model_meta && meta);
+
+public:
+    server_models(const common_params & params, int argc, char ** argv);
+
+    void load_models();
+
+    // check whether a model instance exists (thread-safe)
+    bool has_model(const std::string & name);
+
+    // return a copy of the model's metadata (thread-safe)
+    std::optional<server_model_meta> get_meta(const std::string & name);
+
+    // return a copy of all model metadata (thread-safe)
+    std::vector<server_model_meta> get_all_meta();
+
+    // load and unload model instances
+    // these functions are thread-safe
+    void load(const std::string & name);
+    void unload(const std::string & name);
+    void unload_all();
+
+    // update the status of a model instance (thread-safe)
+    void update_status(const std::string & name, server_model_status status, int exit_code);
+
+    // wait until the model instance is fully loaded (thread-safe)
+    // returns once the model has either loaded or failed to load
+    void wait_until_loaded(const std::string & name);
+
+    // load the model if it is not loaded yet, otherwise do nothing (thread-safe)
+    // returns false if the model is already loaded; returns true otherwise (the meta may need to be refreshed)
+    bool ensure_model_loaded(const std::string & name);
+
+    // proxy an HTTP request to the model instance
+    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
+
+    // notify the router server that a model instance is ready
+    // returns the monitoring thread (to be joined by the caller)
+    static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
+};
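+
+// Illustrative sketch (not part of the original header): one plausible call
+// sequence for a router-side handler. `models`, `name` and `req` are
+// hypothetical locals and error handling is omitted.
+//
+//   if (!models.has_model(name)) {
+//       // unknown model: reply with an error instead of proxying
+//   }
+//   if (models.ensure_model_loaded(name)) {
+//       // a load may have been triggered; block until it is loaded or has failed
+//       models.wait_until_loaded(name);
+//   }
+//   server_http_res_ptr res = models.proxy_request(req, "POST", name, /* update_last_used */ true);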
+
+struct server_models_routes {
+    common_params params;
+    json webui_settings = json::object();
+    server_models models;
+    server_models_routes(const common_params & params, int argc, char ** argv)
+            : params(params), models(params, argc, argv) {
+        if (!this->params.webui_config_json.empty()) {
+            try {
+                webui_settings = json::parse(this->params.webui_config_json);
+            } catch (const std::exception & e) {
+                LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
+                throw;
+            }
+        }
+        init_routes();
+    }
+
+    void init_routes();
+    // handlers are stored as lambdas so they can capture `this` without `std::bind`
+    server_http_context::handler_t get_router_props;
+    server_http_context::handler_t proxy_get;
+    server_http_context::handler_t proxy_post;
+    server_http_context::handler_t get_router_models;
+    server_http_context::handler_t post_router_models_load;
+    server_http_context::handler_t post_router_models_unload;
+};
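+
+// Illustrative sketch (not part of the original header): the kind of
+// assignment init_routes() can perform so that handlers capture `this`
+// without std::bind. The exact signature of server_http_context::handler_t is
+// defined in server-http.h; the lambda parameter and body below are placeholders.
+//
+//   proxy_get = [this](const server_http_req & req) {
+//       // `this` gives direct access to `models`, `params` and `webui_settings`
+//       return models.proxy_request(req, "GET", "placeholder-model-name", /* update_last_used */ true);
+//   };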
+
+/**
+ * A simple HTTP proxy that forwards requests to another server
+ * and relays the responses back.
+ */
+struct server_http_proxy : server_http_res {
+    std::function<void()> cleanup = nullptr;
+public:
+    server_http_proxy(const std::string & method,
+                      const std::string & host,
+                      int port,
+                      const std::string & path,
+                      const std::map<std::string, std::string> & headers,
+                      const std::string & body,
+                      const std::function<bool()> should_stop,
+                      int32_t timeout_read,
+                      int32_t timeout_write
+    );
+    ~server_http_proxy() {
+        if (cleanup) {
+            cleanup();
+        }
+    }
+private:
+    std::thread thread;
+    struct msg_t {
+        std::map<std::string, std::string> headers;
+        int status = 0;
+        std::string data;
+        std::string content_type;
+    };
+};
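+
+// Illustrative sketch (not part of the original header): constructing the
+// proxy, roughly as server_models::proxy_request() might do when forwarding a
+// request to a model instance. All values below are hypothetical placeholders;
+// timeout units are whatever server-http.h defines for timeout_read/write.
+//
+//   std::map<std::string, std::string> headers = { /* headers to forward */ };
+//   std::string body = "...";
+//   auto res = std::make_unique<server_http_proxy>(
+//       "POST", "127.0.0.1", meta.port, "/v1/chat/completions",
+//       headers, body,
+//       []() { return false; },  // should_stop: never abort in this sketch
+//       /* timeout_read  */ 600,
+//       /* timeout_write */ 600);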