1#include "common.h"
  2#include "server-http.h"
  3#include "server-common.h"
  4
  5#include <cpp-httplib/httplib.h>
  6
  7#include <functional>
  8#include <string>
  9#include <thread>
 10
 11// auto generated files (see README.md for details)
 12#include "index.html.gz.hpp"
 13#include "loading.html.hpp"
 14
 15//
 16// HTTP implementation using cpp-httplib
 17//
 18
// Private implementation of server_http_context (pimpl).
// It only holds the cpp-httplib server instance, which keeps the httplib
// types confined to this translation unit.
class server_http_context::Impl {
public:
    // The underlying HTTP(S) server. It is null until init() creates it;
    // stop() checks for null before using it.
    std::unique_ptr<httplib::Server> srv;
};
 23
 24server_http_context::server_http_context()
 25    : pimpl(std::make_unique<server_http_context::Impl>())
 26{}
 27
 28server_http_context::~server_http_context() = default;
 29
 30static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
 31    // skip logging requests that are regularly sent, to avoid log spam
 32    if (req.path == "/health"
 33        || req.path == "/v1/health"
 34        || req.path == "/models"
 35        || req.path == "/v1/models"
 36        || req.path == "/props"
 37        || req.path == "/metrics"
 38    ) {
 39        return;
 40    }
 41
 42    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
 43
 44    SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
 45
 46    SRV_DBG("request:  %s\n", req.body.c_str());
 47    SRV_DBG("response: %s\n", res.body.c_str());
 48}
 49
 50bool server_http_context::init(const common_params & params) {
 51    path_prefix = params.api_prefix;
 52    port = params.port;
 53    hostname = params.hostname;
 54
 55    auto & srv = pimpl->srv;
 56
 57#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 58    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
 59        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
 60        srv.reset(
 61            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
 62        );
 63    } else {
 64        LOG_INF("Running without SSL\n");
 65        srv.reset(new httplib::Server());
 66    }
 67#else
 68    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
 69        LOG_ERR("Server is built without SSL support\n");
 70        return false;
 71    }
 72    srv.reset(new httplib::Server());
 73#endif
 74
 75    srv->set_default_headers({{"Server", "llama.cpp"}});
 76    srv->set_logger(log_server_request);
 77    srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
 78        // this is fail-safe; exceptions should already handled by `ex_wrapper`
 79
 80        std::string message;
 81        try {
 82            std::rethrow_exception(ep);
 83        } catch (const std::exception & e) {
 84            message = e.what();
 85        } catch (...) {
 86            message = "Unknown Exception";
 87        }
 88
 89        res.status = 500;
 90        res.set_content(message, "text/plain");
 91        LOG_ERR("got exception: %s\n", message.c_str());
 92    });
 93
 94    srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
 95        if (res.status == 404) {
 96            res.set_content(
 97                safe_json_to_str(json {
 98                    {"error", {
 99                        {"message", "File Not Found"},
100                        {"type", "not_found_error"},
101                        {"code", 404}
102                    }}
103                }),
104                "application/json; charset=utf-8"
105            );
106        }
107        // for other error codes, we skip processing here because it's already done by res->error()
108    });
109
110    // set timeouts and change hostname and port
111    srv->set_read_timeout (params.timeout_read);
112    srv->set_write_timeout(params.timeout_write);
113
114    if (params.api_keys.size() == 1) {
115        auto key = params.api_keys[0];
116        std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
117        LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
118    } else if (params.api_keys.size() > 1) {
119        LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
120    }
121
122    //
123    // Middlewares
124    //
125
126    auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
127        static const std::unordered_set<std::string> public_endpoints = {
128            "/health",
129            "/v1/health",
130            "/models",
131            "/v1/models",
132            "/api/tags"
133        };
134
135        // If API key is not set, skip validation
136        if (api_keys.empty()) {
137            return true;
138        }
139
140        // If path is public or is static file, skip validation
141        if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
142            return true;
143        }
144
145        // Check for API key in the Authorization header
146        std::string req_api_key = req.get_header_value("Authorization");
147        if (req_api_key.empty()) {
148            // retry with anthropic header
149            req_api_key = req.get_header_value("X-Api-Key");
150        }
151
152        // remove the "Bearer " prefix if needed
153        std::string prefix = "Bearer ";
154        if (req_api_key.substr(0, prefix.size()) == prefix) {
155            req_api_key = req_api_key.substr(prefix.size());
156        }
157
158        // validate the API key
159        if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
160            return true; // API key is valid
161        }
162
163        // API key is invalid or not provided
164        res.status = 401;
165        res.set_content(
166            safe_json_to_str(json {
167                {"error", {
168                    {"message", "Invalid API Key"},
169                    {"type", "authentication_error"},
170                    {"code", 401}
171                }}
172            }),
173            "application/json; charset=utf-8"
174        );
175
176        LOG_WRN("Unauthorized: Invalid API Key\n");
177
178        return false;
179    };
180
181    auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
182        bool ready = is_ready.load();
183        if (!ready) {
184            auto tmp = string_split<std::string>(req.path, '.');
185            if (req.path == "/" || tmp.back() == "html") {
186                res.status = 503;
187                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
188            } else {
189                // no endpoints is allowed to be accessed when the server is not ready
190                // this is to prevent any data races or inconsistent states
191                res.status = 503;
192                res.set_content(
193                    safe_json_to_str(json {
194                        {"error", {
195                            {"message", "Loading model"},
196                            {"type", "unavailable_error"},
197                            {"code", 503}
198                        }}
199                    }),
200                    "application/json; charset=utf-8"
201                );
202            }
203            return false;
204        }
205        return true;
206    };
207
208    // register server middlewares
209    srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
210        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
211        // If this is OPTIONS request, skip validation because browsers don't include Authorization header
212        if (req.method == "OPTIONS") {
213            res.set_header("Access-Control-Allow-Credentials", "true");
214            res.set_header("Access-Control-Allow-Methods",     "GET, POST");
215            res.set_header("Access-Control-Allow-Headers",     "*");
216            res.set_content("", "text/html"); // blank response, no data
217            return httplib::Server::HandlerResponse::Handled; // skip further processing
218        }
219        if (!middleware_server_state(req, res)) {
220            return httplib::Server::HandlerResponse::Handled;
221        }
222        if (!middleware_validate_api_key(req, res)) {
223            return httplib::Server::HandlerResponse::Handled;
224        }
225        return httplib::Server::HandlerResponse::Unhandled;
226    });
227
228    int n_threads_http = params.n_threads_http;
229    if (n_threads_http < 1) {
230        // +2 threads for monitoring endpoints
231        n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
232    }
233    LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
234    srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
235
236    //
237    // Web UI setup
238    //
239
240    if (!params.webui) {
241        LOG_INF("Web UI is disabled\n");
242    } else {
243        // register static assets routes
244        if (!params.public_path.empty()) {
245            // Set the base directory for serving static files
246            bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
247            if (!is_found) {
248                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
249                return 1;
250            }
251        } else {
252            // using embedded static index.html
253            srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
254                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
255                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
256                } else {
257                    res.set_header("Content-Encoding", "gzip");
258                    // COEP and COOP headers, required by pyodide (python interpreter)
259                    res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
260                    res.set_header("Cross-Origin-Opener-Policy", "same-origin");
261                    res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
262                }
263                return false;
264            });
265        }
266    }
267    return true;
268}
269
270bool server_http_context::start() {
271    // Bind and listen
272
273    auto & srv = pimpl->srv;
274    bool was_bound = false;
275    bool is_sock = false;
276    if (string_ends_with(std::string(hostname), ".sock")) {
277        is_sock = true;
278        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
279        srv->set_address_family(AF_UNIX);
280        // bind_to_port requires a second arg, any value other than 0 should
281        // simply get ignored
282        was_bound = srv->bind_to_port(hostname, 8080);
283    } else {
284        LOG_INF("%s: binding port with default address family\n", __func__);
285        // bind HTTP listen port
286        if (port == 0) {
287            int bound_port = srv->bind_to_any_port(hostname);
288            was_bound = (bound_port >= 0);
289            if (was_bound) {
290                port = bound_port;
291            }
292        } else {
293            was_bound = srv->bind_to_port(hostname, port);
294        }
295    }
296
297    if (!was_bound) {
298        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
299        return false;
300    }
301
302    // run the HTTP server in a thread
303    thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
304    srv->wait_until_ready();
305
306    listening_address = is_sock ? string_format("unix://%s",    hostname.c_str())
307                                : string_format("http://%s:%d", hostname.c_str(), port);
308    return true;
309}
310
311void server_http_context::stop() const {
312    if (pimpl->srv) {
313        pimpl->srv->stop();
314    }
315}
316
317static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
318    for (const auto & [key, value] : headers) {
319        res.set_header(key, value);
320    }
321}
322
323static std::map<std::string, std::string> get_params(const httplib::Request & req) {
324    std::map<std::string, std::string> params;
325    for (const auto & [key, value] : req.params) {
326        params[key] = value;
327    }
328    for (const auto & [key, value] : req.path_params) {
329        params[key] = value;
330    }
331    return params;
332}
333
334static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
335    std::map<std::string, std::string> headers;
336    for (const auto & [key, value] : req.headers) {
337        headers[key] = value;
338    }
339    return headers;
340}
341
342// using unique_ptr for request to allow safe capturing in lambdas
343using server_http_req_ptr = std::unique_ptr<server_http_req>;
344
345static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
346    if (response->is_stream()) {
347        res.status = response->status;
348        set_headers(res, response->headers);
349        std::string content_type = response->content_type;
350        // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
351        std::shared_ptr<server_http_req> q_ptr = std::move(request);
352        std::shared_ptr<server_http_res> r_ptr = std::move(response);
353        const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
354            std::string chunk;
355            bool has_next = response->next(chunk);
356            if (!chunk.empty()) {
357                // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
358                sink.write(chunk.data(), chunk.size());
359                SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
360            }
361            if (!has_next) {
362                sink.done();
363                SRV_DBG("%s", "http: stream ended\n");
364            }
365            return has_next;
366        };
367        const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
368            response.reset(); // trigger the destruction of the response object
369            request.reset();  // trigger the destruction of the request object
370        };
371        res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
372    } else {
373        res.status = response->status;
374        set_headers(res, response->headers);
375        res.set_content(response->data, response->content_type);
376    }
377}
378
379void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
380    pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
381        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
382            get_params(req),
383            get_headers(req),
384            req.path,
385            req.body,
386            req.is_connection_closed
387        });
388        server_http_res_ptr response = handler(*request);
389        process_handler_response(std::move(request), response, res);
390    });
391}
392
393void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
394    pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
395        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
396            get_params(req),
397            get_headers(req),
398            req.path,
399            req.body,
400            req.is_connection_closed
401        });
402        server_http_res_ptr response = handler(*request);
403        process_handler_response(std::move(request), response, res);
404    });
405}
406