#include "server-common.h"
#include "server-models.h"

#include "preset.h"
#include "download.h"

#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>

#include <functional>
#include <algorithm>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <cstring>
#include <atomic>
#include <chrono>
#include <queue>
#include <filesystem>

#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
extern char **environ;
#endif

#if defined(__APPLE__) && defined(__MACH__)
// macOS: use _NSGetExecutablePath to get the executable path
#include <mach-o/dyld.h>
#include <limits.h>
#endif

#define DEFAULT_STOP_TIMEOUT 10 // seconds

#define CMD_ROUTER_TO_CHILD_EXIT  "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"

// address for child processes; they bind to loopback because the router itself may listen on 0.0.0.0
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
#define CHILD_ADDR "127.0.0.1"

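// resolve the absolute path of the running server binary so the router can
// re-exec the same binary for child model instances
// (Windows: GetModuleFileNameW, macOS: _NSGetExecutablePath, elsewhere: /proc/self/exe)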
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 };  // large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);

    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get the absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate the required size and call again
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, sizeof(path));
    if (count <= 0 || count >= (ssize_t) sizeof(path)) {
        // count == sizeof(path) means the result may have been truncated
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    return std::filesystem::path(std::string(path, count));
#endif
}

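// strip options that are managed by the router itself (TLS, API key, and the
// router-level model options) so they are never forwarded to child instances;
// when unset_model_args is true, also drop the model selection args so that
// each preset supplies its own model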
static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
    preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
    preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
    preset.unset_option("LLAMA_API_KEY");
    preset.unset_option("LLAMA_ARG_MODELS_DIR");
    preset.unset_option("LLAMA_ARG_MODELS_MAX");
    preset.unset_option("LLAMA_ARG_MODELS_PRESET");
    preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
    if (unset_model_args) {
        preset.unset_option("LLAMA_ARG_MODEL");
        preset.unset_option("LLAMA_ARG_MMPROJ");
        preset.unset_option("LLAMA_ARG_HF_REPO");
    }
}

#ifdef _WIN32
static std::string wide_to_utf8(const wchar_t * ws) {
    if (!ws || !*ws) {
        return {};
    }

    const int len = static_cast<int>(std::wcslen(ws));
    const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr);
    if (bytes == 0) {
        return {};
    }

    std::string utf8(bytes, '\0');
    WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr);

    return utf8;
}
#endif

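// snapshot the current process environment as "KEY=value" strings (converted
// to UTF-8 on Windows); the snapshot is later passed to spawned child instances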
static std::vector<std::string> get_environment() {
    std::vector<std::string> env;

#ifdef _WIN32
    LPWCH env_block = GetEnvironmentStringsW();
    if (!env_block) {
        return env;
    }
    for (LPWCH e = env_block; *e; e += wcslen(e) + 1) {
        env.emplace_back(wide_to_utf8(e));
    }
    FreeEnvironmentStringsW(env_block);
#else
    if (environ == nullptr) {
        return env;
    }
    for (char ** e = environ; *e != nullptr; e++) {
        env.emplace_back(*e);
    }
#endif

    return env;
}

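// re-render the argv for this instance: force the child to bind CHILD_ADDR on
// the assigned port and use the model name as its alias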
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    unset_reserved_args(preset, false);
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST",  CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT",  std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate preset before rendering ?
    // render args
    args = preset.to_args(bin_path);
}

//
// server_models
//

server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv)
            : ctx_preset(LLAMA_EXAMPLE_SERVER),
              base_params(params),
              base_env(get_environment()),
              base_preset(ctx_preset.load_from_args(argc, argv)) {
    // clean up base preset
    unset_reserved_args(base_preset, true);
    // set binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}

void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    meta.update_args(ctx_preset, bin_path); // render args
    std::string name = meta.name;
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<subprocess_s>(),
        /* th      */ std::thread(),
        /* meta    */ std::move(meta)
    };
}

// TODO: allow refreshing cached model list
void server_models::load_models() {
    // loading models from 3 sources:
    // 1. cached models
    common_presets cached_models = ctx_preset.load_from_cache();
    SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
    // 2. local models from --models-dir
    common_presets local_models;
    if (!base_params.models_dir.empty()) {
        local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
        SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
    }
    // 3. custom-path models from presets
    common_preset global = {};
    common_presets custom_presets = {};
    if (!base_params.models_preset.empty()) {
        custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
        SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
    }

    // cascade, apply global preset first
    cached_models  = ctx_preset.cascade(global, cached_models);
    local_models   = ctx_preset.cascade(global, local_models);
    custom_presets = ctx_preset.cascade(global, custom_presets);

    // note: if a model exists in both cached and local, local takes precedence
    common_presets final_presets;
    for (const auto & [name, preset] : cached_models) {
        final_presets[name] = preset;
    }
    for (const auto & [name, preset] : local_models) {
        final_presets[name] = preset;
    }

    // process custom presets from INI
    for (const auto & [name, custom] : custom_presets) {
        if (final_presets.find(name) != final_presets.end()) {
            // apply custom config if it exists
            common_preset & target = final_presets[name];
            target.merge(custom);
        } else {
            // otherwise add directly
            final_presets[name] = custom;
        }
    }

    // the server base preset from CLI args takes highest precedence
    for (auto & [name, preset] : final_presets) {
        preset.merge(base_preset);
    }

    // convert presets to server_model_meta and add to mapping
    for (const auto & preset : final_presets) {
        server_model_meta meta{
            /* preset       */ preset.second,
            /* name         */ preset.first,
            /* port         */ 0,
            /* status       */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used    */ 0,
            /* args         */ std::vector<std::string>(),
            /* exit_code    */ 0,
            /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
        };
        add_model(std::move(meta));
    }

    // log available models
    {
        std::unordered_set<std::string> custom_names;
        for (const auto & [name, preset] : custom_presets) {
            custom_names.insert(name);
        }
        SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
        for (const auto & [name, inst] : mapping) {
            bool has_custom = custom_names.find(name) != custom_names.end();
            SRV_INF("  %c %s\n", has_custom ? '*' : ' ', name.c_str());
        }
    }

    // handle custom stop-timeout option
    for (auto & [name, inst] : mapping) {
        std::string val;
        if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
            try {
                inst.meta.stop_timeout = std::stoi(val);
            } catch (...) {
                SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
                    val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
                inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
            }
        }
    }

    // load any autoload models
    std::vector<std::string> models_to_load;
    for (const auto & [name, inst] : mapping) {
        std::string val;
        if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
            models_to_load.push_back(name);
        }
    }
    // models_max <= 0 means no limit
    if (base_params.models_max > 0 && (int) models_to_load.size() > base_params.models_max) {
        throw std::runtime_error(string_format(
            "number of models to load on startup (%zu) exceeds models_max (%d)",
            models_to_load.size(),
            base_params.models_max
        ));
    }
    for (const auto & name : models_to_load) {
        SRV_INF("(startup) loading model %s\n", name.c_str());
        load(name);
    }
}

void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        it->second.meta = meta;
    }
    cv.notify_all(); // notify wait_until_loaded
}

bool server_models::has_model(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    return mapping.find(name) != mapping.end();
}

std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        return it->second.meta;
    }
    return std::nullopt;
}

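// ask the OS for an ephemeral port by binding a socket to port 0, reading back
// the assigned port, then closing the socket; returns -1 on failure
// note: this is inherently racy, another process could grab the port before the
// child instance binds it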
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif

    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0);

    if (bind(sock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    if (getsockname(sock, (struct sockaddr *) &serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }

    int port = ntohs(serv_addr.sin_port);

    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif

    return port;
}

// helper to convert vector<string> to char **
// the pointers are only valid as long as the original vector is alive
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
    std::vector<char *> result;
    result.reserve(vec.size() + 1);
    for (const auto & s : vec) {
        result.push_back(const_cast<char *>(s.c_str()));
    }
    result.push_back(nullptr);
    return result;
}

std::vector<server_model_meta> server_models::get_all_meta() {
    std::lock_guard<std::mutex> lk(mutex);
    std::vector<server_model_meta> result;
    result.reserve(mapping.size());
    for (const auto & [name, inst] : mapping) {
        result.push_back(inst.meta);
    }
    return result;
}

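// evict the least-recently-used active instance once the number of active
// instances reaches models_max; blocks until the evicted instance has fully
// unloaded so the caller can safely spawn a new one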
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // evict one of the servers if we reached models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms();
    size_t count_active = 0;
    {
        std::unique_lock<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    if (!lru_model_name.empty() && count_active >= (size_t) base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
        // wait for the unload to complete
        {
            std::unique_lock<std::mutex> lk(mutex);
            cv.wait(lk, [this, &lru_model_name]() {
                return mapping.at(lru_model_name).meta.status == SERVER_MODEL_STATUS_UNLOADED;
            });
        }
    }
}

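// spawn a child server process for the given model:
//   1. evict the LRU instance if the models_max limit has been reached
//   2. grab a free port and render the child argv
//   3. spawn the child with stdout/stderr combined into one pipe
//   4. start a manager thread that forwards child logs, handles the ready/exit
//      handshake and updates the model status when the child exits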
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    unload_lru();

    std::lock_guard<std::mutex> lk(mutex);

    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model %s is already loading or loaded, skipping\n", name.c_str());
        return;
    }

    // prepare new instance info
    instance_t inst;
    inst.meta           = meta;
    inst.meta.port      = get_free_port();
    inst.meta.status    = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();

    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }

    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);

        inst.meta.update_args(ctx_preset, bin_path); // render args

        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env  = base_env; // copy
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));

        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF("  %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging

        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);

        // TODO @ngxson : maybe separate stdout and stderr in the future
        //                so that we can use stdout for commands and stderr for logging
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }

        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }

    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
        FILE * stdin_file = subprocess_stdin(child_proc.get());
        FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr

        std::thread log_thread([&]() {
            // read stdout/stderr and forward to the main server log
            // also handle the status report from the child process
            bool state_received = false; // true once the child reported its state
            if (stdout_file) {
                char buffer[4096];
                while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
                    LOG("[%5d] %s", port, buffer);
                    if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
                        // child process is ready
                        this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
                        state_received = true;
                    }
                }
            } else {
                SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
            }
        });

        std::thread stopping_thread([&]() {
            // thread to monitor the stopping signal
            auto is_stopping = [this, &name]() {
                return this->stopping_models.find(name) != this->stopping_models.end();
            };
            {
                std::unique_lock<std::mutex> lk(this->mutex);
                this->cv_stop.wait(lk, is_stopping);
            }
            SRV_INF("stopping model instance name=%s\n", name.c_str());
            // send interrupt to the child process
            fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
            fflush(stdin_file);
            // wait for it to stop gracefully, or time out
            int64_t start_time = ggml_time_ms();
            while (true) {
                std::unique_lock<std::mutex> lk(this->mutex);
                if (!is_stopping()) {
                    return; // already stopped
                }
                int64_t elapsed = ggml_time_ms() - start_time;
                if (elapsed >= stop_timeout * 1000) {
                    // timeout, force kill
                    SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
                    subprocess_terminate(child_proc.get());
                    return;
                }
                this->cv_stop.wait_for(lk, std::chrono::seconds(1));
            }
        });

        // log_thread returns when the child's stdout reaches EOF, i.e. when the child exits,
        // so joining it here is how we wait for the child process to exit
        // note: we cannot call subprocess_join() before this point because it closes stdin_file,
        // which stopping_thread may still need for sending the exit command
        if (log_thread.joinable()) {
            log_thread.join();
        }

        // stop the timeout monitoring thread
        {
            std::lock_guard<std::mutex> lk(this->mutex);
            stopping_models.erase(name);
            cv_stop.notify_all();
        }
        if (stopping_thread.joinable()) {
            stopping_thread.join();
        }

        // get the exit code
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());

        // update status and exit code
        this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });

    // clean up the old process/thread if they exist
    {
        auto & old_instance = mapping[name];
        // the old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }

    mapping[name] = std::move(inst);
    cv.notify_all();
}

void server_models::unload(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        if (it->second.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            stopping_models.insert(name);
            cv_stop.notify_all();
            // the status change will be handled by the managing thread
        } else {
            SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
        }
    }
}

void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                stopping_models.insert(name);
                cv_stop.notify_all();
                // the status change will be handled by the managing thread
            }
            // move the thread to the join list; joining happens outside the lock to avoid a deadlock
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}

void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
    std::unique_lock<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        auto & meta = it->second.meta;
        meta.status    = status;
        meta.exit_code = exit_code;
    }
    cv.notify_all();
}

void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false;
    });
}

bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name);
    }

    // the model is now in loading state; wait for it to finish
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);

    // check the final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }

    return true;
}

server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    auto proxy = std::make_unique<server_http_proxy>(
            method,
            CHILD_ADDR,
            meta->port,
            req.path,
            req.headers,
            req.body,
            req.should_stop,
            base_params.timeout_read,
            base_params.timeout_write
            );
    return proxy;
}

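// called from a child instance: print the ready marker on stdout so that the
// router can mark the model as loaded, then return a thread that watches stdin
// for the exit command (graceful shutdown) or EOF (the router died; exit immediately)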
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that this model instance is ready
    common_log_pause(common_log_main());
    fflush(stdout);
    fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
    fflush(stdout);
    common_log_resume(common_log_main());

    // set up a thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, meaning the router server exited unexpectedly or was killed
                eof = true;
                break;
            }
            if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}



//
// server_models_routes
//

static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
    res->status = 200;
    res->data = safe_json_to_str(response_data);
}

static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
    res->status = json_value(error_data, "code", 500);
    res->data = safe_json_to_str({{ "error", error_data }});
}

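// validate the model named in a request; on failure, fill `res` with an error
// response and return false
// when autoload is enabled, a known-but-unloaded model is loaded on demand
// instead of being rejected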
static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
    if (name.empty()) {
        res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    auto meta = models.get_meta(name);
    if (!meta.has_value()) {
        res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    if (models_autoload) {
        models.ensure_model_loaded(name);
    } else {
        if (meta->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return false;
        }
    }
    return true;
}

static bool is_autoload(const common_params & params, const server_http_req & req) {
    std::string autoload = req.get_param("autoload");
    if (autoload.empty()) {
        return params.models_autoload;
    } else {
        return autoload == "true" || autoload == "1";
    }
}

void server_models_routes::init_routes() {
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role",          "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure the webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path",  "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx",  0},
                }},
                {"webui_settings", webui_settings},
            });
            return res;
        }
        return proxy_get(req);
    };

    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };

    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST requests only
    };

    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };

    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(nullptr);
        for (const auto & meta : all_models) {
            json status {
                {"value",  server_model_status_to_string(meta.status)},
                {"args",   meta.args},
            };
            if (!meta.preset.name.empty()) {
                common_preset preset_copy = meta.preset;
                unset_reserved_args(preset_copy, false);
                preset_copy.unset_option("LLAMA_ARG_HOST");
                preset_copy.unset_option("LLAMA_ARG_PORT");
                preset_copy.unset_option("LLAMA_ARG_ALIAS");
                status["preset"] = preset_copy.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"]    = true;
            }
            models_json.push_back(json {
                {"id",       meta.name},
                {"object",   "model"},    // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created",  t},          // for OAI-compat
                {"status",   status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };

    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (!model->is_active()) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}



//
// server_http_proxy
//

// simple implementation of a pipe
// used for streaming data between threads
template<typename T>
struct pipe_t {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<T> queue;
    std::atomic<bool> writer_closed{false};
    std::atomic<bool> reader_closed{false};
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to the writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};

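// illustrative usage sketch (not taken verbatim from the code below): one writer
// thread streams items and closes the pipe when done, while a reader drains it:
//
//   pipe_t<std::string> p;
//   std::thread writer([&]() { p.write("chunk"); p.close_write(); });
//   std::string chunk;
//   while (p.read(chunk, []() { return false; })) { /* consume chunk */ }
//   writer.join();
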
static std::string to_lower_copy(const std::string & value) {
    std::string lowered(value.size(), '\0');
    std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
    return lowered;
}

static bool should_strip_proxy_header(const std::string & header_name) {
    // headers that get duplicated when the router forwards child responses
    if (header_name == "server" ||
        header_name == "transfer-encoding" ||
        header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
        header_name == "keep-alive") {
        return true;
    }

    // the router injects CORS headers and the child also sends them: duplicate
    if (header_name.rfind("access-control-", 0) == 0) {
        return true;
    }

    return false;
}

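// forward an incoming request to a child instance: a detached worker thread runs
// the actual HTTP request against the child and streams the response (headers
// first, then body chunks) through a pipe_t; the constructor blocks until the
// headers arrive, after which next() yields body chunks to the caller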
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop,
        int32_t timeout_read,
        int32_t timeout_write
        ) {
    // shared between the reader and writer threads
    auto cli  = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();

    // set up the client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    cli->set_write_timeout(timeout_read, 0); // read/write are reversed for cli (client) vs srv (server)
    cli->set_read_timeout(timeout_write, 0);
    this->status = 500; // to be overwritten upon response
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };

    // wire up the receiving end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or the pipe is broken
    };

    // wire up the HTTP client
    // note: do NOT capture the `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            if (lowered == "content-type") {
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send the headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if the pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };

    // prepare the request to the destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }

    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to the reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    this->thread.detach();

    // wait for the first chunk (the headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status  = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}