1#pragma once
2
3#include "common.h"
4#include "preset.h"
5#include "server-common.h"
6#include "server-http.h"
7
8#include <mutex>
9#include <condition_variable>
10#include <functional>
11#include <memory>
12#include <set>
13
/**
 * state diagram:
 *
 * UNLOADED ──► LOADING ──► LOADED
 *  ▲            │            │
 *  └───failed───┘            │
 *  ▲                         │
 *  └────────unloaded─────────┘
 *
 * (a failed load and an unload both lead back to UNLOADED)
 */
// Lifecycle state of a single model instance.
// The numeric values are written out explicitly; they are the same 0/1/2 the
// compiler would assign implicitly.
enum server_model_status {
    // TODO: also add downloading state when the logic is added
    SERVER_MODEL_STATUS_UNLOADED = 0, // also the default-initialized status of server_model_meta
    SERVER_MODEL_STATUS_LOADING  = 1,
    SERVER_MODEL_STATUS_LOADED   = 2,
};
29
30static server_model_status server_model_status_from_string(const std::string & status_str) {
31 if (status_str == "unloaded") {
32 return SERVER_MODEL_STATUS_UNLOADED;
33 }
34 if (status_str == "loading") {
35 return SERVER_MODEL_STATUS_LOADING;
36 }
37 if (status_str == "loaded") {
38 return SERVER_MODEL_STATUS_LOADED;
39 }
40 throw std::runtime_error("invalid server model status");
41}
42
43static std::string server_model_status_to_string(server_model_status status) {
44 switch (status) {
45 case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
46 case SERVER_MODEL_STATUS_LOADING: return "loading";
47 case SERVER_MODEL_STATUS_LOADED: return "loaded";
48 default: return "unknown";
49 }
50}
51
52struct server_model_meta {
53 common_preset preset;
54 std::string name;
55 int port = 0;
56 server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
57 int64_t last_used = 0; // for LRU unloading
58 std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
59 int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
60 int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
61
62 bool is_active() const {
63 return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
64 }
65
66 bool is_failed() const {
67 return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
68 }
69
70 void update_args(common_preset_context & ctx_presets, std::string bin_path);
71};
72
73struct subprocess_s;
74
75struct server_models {
76private:
77 struct instance_t {
78 std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
79 std::thread th;
80 server_model_meta meta;
81 FILE * stdin_file = nullptr;
82 };
83
84 std::mutex mutex;
85 std::condition_variable cv;
86 std::map<std::string, instance_t> mapping;
87
88 // for stopping models
89 std::condition_variable cv_stop;
90 std::set<std::string> stopping_models;
91
92 common_preset_context ctx_preset;
93
94 common_params base_params;
95 std::string bin_path;
96 std::vector<std::string> base_env;
97 common_preset base_preset; // base preset from llama-server CLI args
98
99 void update_meta(const std::string & name, const server_model_meta & meta);
100
101 // unload least recently used models if the limit is reached
102 void unload_lru();
103
104 // not thread-safe, caller must hold mutex
105 void add_model(server_model_meta && meta);
106
107public:
108 server_models(const common_params & params, int argc, char ** argv);
109
110 void load_models();
111
112 // check if a model instance exists (thread-safe)
113 bool has_model(const std::string & name);
114
115 // return a copy of model metadata (thread-safe)
116 std::optional<server_model_meta> get_meta(const std::string & name);
117
118 // return a copy of all model metadata (thread-safe)
119 std::vector<server_model_meta> get_all_meta();
120
121 // load and unload model instances
122 // these functions are thread-safe
123 void load(const std::string & name);
124 void unload(const std::string & name);
125 void unload_all();
126
127 // update the status of a model instance (thread-safe)
128 void update_status(const std::string & name, server_model_status status, int exit_code);
129
130 // wait until the model instance is fully loaded (thread-safe)
131 // return when the model is loaded or failed to load
132 void wait_until_loaded(const std::string & name);
133
134 // load the model if not loaded, otherwise do nothing (thread-safe)
135 // return false if model is already loaded; return true otherwise (meta may need to be refreshed)
136 bool ensure_model_loaded(const std::string & name);
137
138 // proxy an HTTP request to the model instance
139 server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
140
141 // notify the router server that a model instance is ready
142 // return the monitoring thread (to be joined by the caller)
143 static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
144};
145
146struct server_models_routes {
147 common_params params;
148 json webui_settings = json::object();
149 server_models models;
150 server_models_routes(const common_params & params, int argc, char ** argv)
151 : params(params), models(params, argc, argv) {
152 if (!this->params.webui_config_json.empty()) {
153 try {
154 webui_settings = json::parse(this->params.webui_config_json);
155 } catch (const std::exception & e) {
156 LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
157 throw;
158 }
159 }
160 init_routes();
161 }
162
163 void init_routes();
164 // handlers using lambda function, so that they can capture `this` without `std::bind`
165 server_http_context::handler_t get_router_props;
166 server_http_context::handler_t proxy_get;
167 server_http_context::handler_t proxy_post;
168 server_http_context::handler_t get_router_models;
169 server_http_context::handler_t post_router_models_load;
170 server_http_context::handler_t post_router_models_unload;
171};
172
173/**
174 * A simple HTTP proxy that forwards requests to another server
175 * and relays the responses back.
176 */
177struct server_http_proxy : server_http_res {
178 std::function<void()> cleanup = nullptr;
179public:
180 server_http_proxy(const std::string & method,
181 const std::string & host,
182 int port,
183 const std::string & path,
184 const std::map<std::string, std::string> & headers,
185 const std::string & body,
186 const std::function<bool()> should_stop,
187 int32_t timeout_read,
188 int32_t timeout_write
189 );
190 ~server_http_proxy() {
191 if (cleanup) {
192 cleanup();
193 }
194 }
195private:
196 std::thread thread;
197 struct msg_t {
198 std::map<std::string, std::string> headers;
199 int status = 0;
200 std::string data;
201 std::string content_type;
202 };
203};