1#include "arg.h"
  2#include "preset.h"
  3#include "peg-parser.h"
  4#include "log.h"
  5#include "download.h"
  6
  7#include <fstream>
  8#include <sstream>
  9#include <filesystem>
 10
 11static std::string rm_leading_dashes(const std::string & str) {
 12    size_t pos = 0;
 13    while (pos < str.size() && str[pos] == '-') {
 14        ++pos;
 15    }
 16    return str.substr(pos);
 17}
 18
 19// only allow a subset of args for remote presets for security reasons
 20// do not add more args unless absolutely necessary
 21// args that output to files are strictly prohibited
 22static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
 23    static const std::set<std::string> allowed_options = {
 24        "model-url",
 25        "hf-repo",
 26        "hf-repo-draft",
 27        "hf-repo-v", // vocoder
 28        "hf-file-v", // vocoder
 29        "mmproj-url",
 30        "pooling",
 31        "jinja",
 32        "batch-size",
 33        "ubatch-size",
 34        "cache-reuse",
 35        "chat-template-kwargs",
 36        "mmap",
 37        // note: sampling params are automatically allowed by default
 38        // negated args will be added automatically if the positive arg is specified above
 39    };
 40
 41    std::set<std::string> allowed_keys;
 42
 43    for (const auto & it : key_to_opt) {
 44        const std::string & key = it.first;
 45        const common_arg & opt = it.second;
 46        if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
 47            allowed_keys.insert(key);
 48            // also add variant keys (args without leading dashes and env vars)
 49            for (const auto & arg : opt.get_args()) {
 50                allowed_keys.insert(rm_leading_dashes(arg));
 51            }
 52            for (const auto & env : opt.get_env()) {
 53                allowed_keys.insert(env);
 54            }
 55        }
 56    }
 57
 58    return allowed_keys;
 59}
 60
 61std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
 62    std::vector<std::string> args;
 63
 64    if (!bin_path.empty()) {
 65        args.push_back(bin_path);
 66    }
 67
 68    for (const auto & [opt, value] : options) {
 69        if (opt.is_preset_only) {
 70            continue; // skip preset-only options (they are not CLI args)
 71        }
 72
 73        // use the last arg as the main arg (i.e. --long-form)
 74        args.push_back(opt.args.back());
 75
 76        // handle value(s)
 77        if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
 78            // flag option, no value
 79            if (common_arg_utils::is_falsey(value)) {
 80                // use negative arg if available
 81                if (!opt.args_neg.empty()) {
 82                    args.back() = opt.args_neg.back();
 83                } else {
 84                    // otherwise, skip the flag
 85                    // TODO: maybe throw an error instead?
 86                    args.pop_back();
 87                }
 88            }
 89        }
 90        if (opt.value_hint != nullptr) {
 91            // single value
 92            args.push_back(value);
 93        }
 94        if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
 95            throw std::runtime_error(string_format(
 96                "common_preset::to_args(): option '%s' has two values, which is not supported yet",
 97                opt.args.back()
 98            ));
 99        }
100    }
101
102    return args;
103}
104
105std::string common_preset::to_ini() const {
106    std::ostringstream ss;
107
108    ss << "[" << name << "]\n";
109    for (const auto & [opt, value] : options) {
110        auto espaced_value = value;
111        string_replace_all(espaced_value, "\n", "\\\n");
112        ss << rm_leading_dashes(opt.args.back()) << " = ";
113        ss << espaced_value << "\n";
114    }
115    ss << "\n";
116
117    return ss.str();
118}
119
120void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
121    // try if option exists, update it
122    for (auto & [opt, val] : options) {
123        if (opt.env && env == opt.env) {
124            val = value;
125            return;
126        }
127    }
128    // if option does not exist, we need to add it
129    if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
130        throw std::runtime_error(string_format(
131            "%s: option with env '%s' not found in ctx_params",
132            __func__, env.c_str()
133        ));
134    }
135    options[ctx.key_to_opt.at(env)] = value;
136}
137
138void common_preset::unset_option(const std::string & env) {
139    for (auto it = options.begin(); it != options.end(); ) {
140        const common_arg & opt = it->first;
141        if (opt.env && env == opt.env) {
142            it = options.erase(it);
143            return;
144        } else {
145            ++it;
146        }
147    }
148}
149
150bool common_preset::get_option(const std::string & env, std::string & value) const {
151    for (const auto & [opt, val] : options) {
152        if (opt.env && env == opt.env) {
153            value = val;
154            return true;
155        }
156    }
157    return false;
158}
159
160void common_preset::merge(const common_preset & other) {
161    for (const auto & [opt, val] : other.options) {
162        options[opt] = val; // overwrite existing options
163    }
164}
165
166void common_preset::apply_to_params(common_params & params) const {
167    for (const auto & [opt, val] : options) {
168        // apply each option to params
169        if (opt.handler_string) {
170            opt.handler_string(params, val);
171        } else if (opt.handler_int) {
172            opt.handler_int(params, std::stoi(val));
173        } else if (opt.handler_bool) {
174            opt.handler_bool(params, common_arg_utils::is_truthy(val));
175        } else if (opt.handler_str_str) {
176            // not supported yet
177            throw std::runtime_error(string_format(
178                "%s: option with two values is not supported yet",
179                __func__
180            ));
181        } else if (opt.handler_void) {
182            opt.handler_void(params);
183        } else {
184            GGML_ABORT("unknown handler type");
185        }
186    }
187}
188
189static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
190    std::map<std::string, std::map<std::string, std::string>> parsed;
191
192    if (!std::filesystem::exists(path)) {
193        throw std::runtime_error("preset file does not exist: " + path);
194    }
195
196    std::ifstream file(path);
197    if (!file.good()) {
198        throw std::runtime_error("failed to open server preset file: " + path);
199    }
200
201    std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
202
203    static const auto parser = build_peg_parser([](auto & p) {
204        // newline ::= "\r\n" / "\n" / "\r"
205        auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
206
207        // ws ::= [ \t]*
208        auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
209
210        // comment ::= [;#] (!newline .)*
211        auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
212
213        // eol ::= ws comment? (newline / EOF)
214        auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
215
216        // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
217        auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
218
219        // value ::= (!eol-start .)*
220        auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
221        auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
222
223        // header-line ::= "[" ws ident ws "]" eol
224        auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
225
226        // kv-line ::= ident ws "=" ws value eol
227        auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
228
229        // comment-line ::= ws comment (newline / EOF)
230        auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
231
232        // blank-line ::= ws (newline / EOF)
233        auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
234
235        // line ::= header-line / kv-line / comment-line / blank-line
236        auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
237
238        // ini ::= line* EOF
239        auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
240
241        return ini;
242    });
243
244    common_peg_parse_context ctx(contents);
245    const auto result = parser.parse(ctx);
246    if (!result.success()) {
247        throw std::runtime_error("failed to parse server config file: " + path);
248    }
249
250    std::string current_section = COMMON_PRESET_DEFAULT_NAME;
251    std::string current_key;
252
253    ctx.ast.visit(result, [&](const auto & node) {
254        if (node.tag == "section-name") {
255            const std::string section = std::string(node.text);
256            current_section = section;
257            parsed[current_section] = {};
258        } else if (node.tag == "key") {
259            const std::string key = std::string(node.text);
260            current_key = key;
261        } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
262            parsed[current_section][current_key] = std::string(node.text);
263            current_key.clear();
264        }
265    });
266
267    return parsed;
268}
269
270static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
271    std::map<std::string, common_arg> mapping;
272    for (const auto & opt : ctx_params.options) {
273        for (const auto & env : opt.get_env()) {
274            mapping[env] = opt;
275        }
276        for (const auto & arg : opt.get_args()) {
277            mapping[rm_leading_dashes(arg)] = opt;
278        }
279    }
280    return mapping;
281}
282
283static bool is_bool_arg(const common_arg & arg) {
284    return !arg.args_neg.empty();
285}
286
287static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
288    // if this is a negated arg, we need to reverse the value
289    for (const auto & neg_arg : arg.args_neg) {
290        if (rm_leading_dashes(neg_arg) == key) {
291            return common_arg_utils::is_truthy(value) ? "false" : "true";
292        }
293    }
294    // otherwise, not negated
295    return value;
296}
297
298common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
299        : ctx_params(common_params_parser_init(default_params, ex)) {
300    common_params_add_preset_options(ctx_params.options);
301    key_to_opt = get_map_key_opt(ctx_params);
302
303    // setup allowed keys if only_remote_allowed is true
304    if (only_remote_allowed) {
305        filter_allowed_keys = true;
306        allowed_keys = get_remote_preset_whitelist(key_to_opt);
307    }
308}
309
310common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
311    common_presets out;
312    auto ini_data = parse_ini_from_file(path);
313
314    for (auto section : ini_data) {
315        common_preset preset;
316        if (section.first.empty()) {
317            preset.name = COMMON_PRESET_DEFAULT_NAME;
318        } else {
319            preset.name = section.first;
320        }
321        LOG_DBG("loading preset: %s\n", preset.name.c_str());
322        for (const auto & [key, value] : section.second) {
323            if (key == "version") {
324                // skip version key (reserved for future use)
325                continue;
326            }
327
328            LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
329            if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
330                throw std::runtime_error(string_format(
331                    "option '%s' is not allowed in remote presets",
332                    key.c_str()
333                ));
334            }
335            if (key_to_opt.find(key) != key_to_opt.end()) {
336                const auto & opt = key_to_opt.at(key);
337                if (is_bool_arg(opt)) {
338                    preset.options[opt] = parse_bool_arg(opt, key, value);
339                } else {
340                    preset.options[opt] = value;
341                }
342                LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
343            } else {
344                throw std::runtime_error(string_format(
345                    "option '%s' not recognized in preset '%s'",
346                    key.c_str(), preset.name.c_str()
347                ));
348            }
349        }
350
351        if (preset.name == "*") {
352            // handle global preset
353            global = preset;
354        } else {
355            out[preset.name] = preset;
356        }
357    }
358
359    return out;
360}
361
362common_presets common_preset_context::load_from_cache() const {
363    common_presets out;
364
365    auto cached_models = common_list_cached_models();
366    for (const auto & model : cached_models) {
367        common_preset preset;
368        preset.name = model.to_string();
369        preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
370        out[preset.name] = preset;
371    }
372
373    return out;
374}
375
// A GGUF model discovered on disk by load_from_models_dir().
// Field order matters: instances are built with positional aggregate initialization.
struct local_model {
    std::string name;        // preset name: sub-directory name, or file name without ".gguf"
    std::string path;        // model file (first shard for split models)
    std::string path_mmproj; // optional multimodal projector file; empty if none
};
381
382common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
383    if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
384        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
385    }
386
387    std::vector<local_model> models;
388    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
389        auto files = fs_list(subdir_path, false);
390        common_file_info model_file;
391        common_file_info first_shard_file;
392        common_file_info mmproj_file;
393        for (const auto & file : files) {
394            if (string_ends_with(file.name, ".gguf")) {
395                if (file.name.find("mmproj") != std::string::npos) {
396                    mmproj_file = file;
397                } else if (file.name.find("-00001-of-") != std::string::npos) {
398                    first_shard_file = file;
399                } else {
400                    model_file = file;
401                }
402            }
403        }
404        // single file model
405        local_model model{
406            /* name        */ name,
407            /* path        */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
408            /* path_mmproj */ mmproj_file.path // can be empty
409        };
410        if (!model.path.empty()) {
411            models.push_back(model);
412        }
413    };
414
415    auto files = fs_list(models_dir, true);
416    for (const auto & file : files) {
417        if (file.is_dir) {
418            scan_subdir(file.path, file.name);
419        } else if (string_ends_with(file.name, ".gguf")) {
420            // single file model
421            std::string name = file.name;
422            string_replace_all(name, ".gguf", "");
423            local_model model{
424                /* name        */ name,
425                /* path        */ file.path,
426                /* path_mmproj */ ""
427            };
428            models.push_back(model);
429        }
430    }
431
432    // convert local models to presets
433    common_presets out;
434    for (const auto & model : models) {
435        common_preset preset;
436        preset.name = model.name;
437        preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
438        if (!model.path_mmproj.empty()) {
439            preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
440        }
441        out[preset.name] = preset;
442    }
443
444    return out;
445}
446
447common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
448    common_preset preset;
449    preset.name = COMMON_PRESET_DEFAULT_NAME;
450
451    bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
452    if (!ok) {
453        throw std::runtime_error("failed to parse CLI arguments into preset");
454    }
455
456    return preset;
457}
458
459common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
460    common_presets out = base; // copy
461    for (const auto & [name, preset_added] : added) {
462        if (out.find(name) != out.end()) {
463            // if exists, merge
464            common_preset & target = out[name];
465            target.merge(preset_added);
466        } else {
467            // otherwise, add directly
468            out[name] = preset_added;
469        }
470    }
471    return out;
472}
473
474common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
475    common_presets out;
476    for (const auto & [name, preset] : presets) {
477        common_preset tmp = base; // copy
478        tmp.name = name;
479        tmp.merge(preset);
480        out[name] = std::move(tmp);
481    }
482    return out;
483}