1#pragma once
  2
#include "common.h"
#include "preset.h"
#include "server-common.h"
#include "server-http.h"

#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
 13
 14/**
 15 * state diagram:
 16 *
 17 * UNLOADED ──► LOADING ──► LOADED
 18 *  ▲            │            │
 19 *  └───failed───┘            │
 20 *  ▲                         │
 21 *  └────────unloaded─────────┘
 22 */
 23enum server_model_status {
 24    // TODO: also add downloading state when the logic is added
 25    SERVER_MODEL_STATUS_UNLOADED,
 26    SERVER_MODEL_STATUS_LOADING,
 27    SERVER_MODEL_STATUS_LOADED
 28};
 29
 30static server_model_status server_model_status_from_string(const std::string & status_str) {
 31    if (status_str == "unloaded") {
 32        return SERVER_MODEL_STATUS_UNLOADED;
 33    }
 34    if (status_str == "loading") {
 35        return SERVER_MODEL_STATUS_LOADING;
 36    }
 37    if (status_str == "loaded") {
 38        return SERVER_MODEL_STATUS_LOADED;
 39    }
 40    throw std::runtime_error("invalid server model status");
 41}
 42
 43static std::string server_model_status_to_string(server_model_status status) {
 44    switch (status) {
 45        case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
 46        case SERVER_MODEL_STATUS_LOADING:  return "loading";
 47        case SERVER_MODEL_STATUS_LOADED:   return "loaded";
 48        default:                           return "unknown";
 49    }
 50}
 51
 52struct server_model_meta {
 53    common_preset preset;
 54    std::string name;
 55    int port = 0;
 56    server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
 57    int64_t last_used = 0; // for LRU unloading
 58    std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
 59    int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
 60    int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
 61
 62    bool is_active() const {
 63        return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
 64    }
 65
 66    bool is_failed() const {
 67        return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
 68    }
 69
 70    void update_args(common_preset_context & ctx_presets, std::string bin_path);
 71};
 72
// opaque subprocess handle; only used through std::shared_ptr in this header
struct subprocess_s;

// Keeps track of model instances by name (see `mapping`) and exposes the
// operations to load, unload, query and proxy to them.
// NOTE(review): the implementation is not visible from this header; the
// thread-safety remarks below are the ones stated on each declaration.
struct server_models {
private:
    // bookkeeping for a single model instance
    struct instance_t {
        std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
        std::thread th; // presumably the monitoring thread mentioned above - confirm
        server_model_meta meta;
        FILE * stdin_file = nullptr; // presumably a stream to the instance's stdin; ownership not visible here - confirm
    };

    // add_model() requires the caller to hold this mutex
    // NOTE(review): presumably it guards `mapping` as a whole - confirm
    std::mutex mutex;
    std::condition_variable cv;
    // model name -> instance
    std::map<std::string, instance_t> mapping;

    // for stopping models
    std::condition_variable cv_stop;
    std::set<std::string> stopping_models;

    common_preset_context ctx_preset;

    common_params base_params;
    std::string bin_path;
    std::vector<std::string> base_env;
    common_preset base_preset; // base preset from llama-server CLI args

    // NOTE(review): presumably replaces the stored metadata of `name`; locking behaviour not visible here - confirm
    void update_meta(const std::string & name, const server_model_meta & meta);

    // unload least recently used models if the limit is reached
    void unload_lru();

    // not thread-safe, caller must hold mutex
    void add_model(server_model_meta && meta);

public:
    server_models(const common_params & params, int argc, char ** argv);

    void load_models();

    // check if a model instance exists (thread-safe)
    bool has_model(const std::string & name);

    // return a copy of model metadata (thread-safe)
    // NOTE(review): the std::optional return suggests an empty result for an unknown name - confirm
    std::optional<server_model_meta> get_meta(const std::string & name);

    // return a copy of all model metadata (thread-safe)
    std::vector<server_model_meta> get_all_meta();

    // load and unload model instances
    // these functions are thread-safe
    void load(const std::string & name);
    void unload(const std::string & name);
    void unload_all();

    // update the status of a model instance (thread-safe)
    void update_status(const std::string & name, server_model_status status, int exit_code);

    // wait until the model instance is fully loaded (thread-safe)
    // return when the model is loaded or failed to load
    void wait_until_loaded(const std::string & name);

    // load the model if not loaded, otherwise do nothing (thread-safe)
    // return false if model is already loaded; return true otherwise (meta may need to be refreshed)
    bool ensure_model_loaded(const std::string & name);

    // proxy an HTTP request to the model instance
    // NOTE(review): `update_last_used` presumably refreshes server_model_meta::last_used (used for LRU unloading) - confirm
    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);

    // notify the router server that a model instance is ready
    // return the monitoring thread (to be joined by the caller)
    static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
};
145
146struct server_models_routes {
147    common_params params;
148    json webui_settings = json::object();
149    server_models models;
150    server_models_routes(const common_params & params, int argc, char ** argv)
151            : params(params), models(params, argc, argv) {
152        if (!this->params.webui_config_json.empty()) {
153            try {
154                webui_settings = json::parse(this->params.webui_config_json);
155            } catch (const std::exception & e) {
156                LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
157                throw;
158            }
159        }
160        init_routes();
161    }
162
163    void init_routes();
164    // handlers using lambda function, so that they can capture `this` without `std::bind`
165    server_http_context::handler_t get_router_props;
166    server_http_context::handler_t proxy_get;
167    server_http_context::handler_t proxy_post;
168    server_http_context::handler_t get_router_models;
169    server_http_context::handler_t post_router_models_load;
170    server_http_context::handler_t post_router_models_unload;
171};
172
173/**
174 * A simple HTTP proxy that forwards requests to another server
175 * and relays the responses back.
176 */
177struct server_http_proxy : server_http_res {
178    std::function<void()> cleanup = nullptr;
179public:
180    server_http_proxy(const std::string & method,
181                      const std::string & host,
182                      int port,
183                      const std::string & path,
184                      const std::map<std::string, std::string> & headers,
185                      const std::string & body,
186                      const std::function<bool()> should_stop,
187                      int32_t timeout_read,
188                      int32_t timeout_write
189                      );
190    ~server_http_proxy() {
191        if (cleanup) {
192            cleanup();
193        }
194    }
195private:
196    std::thread thread;
197    struct msg_t {
198        std::map<std::string, std::string> headers;
199        int status = 0;
200        std::string data;
201        std::string content_type;
202    };
203};