1#include "llama-grammar.h"
   2
   3#include "llama-impl.h"
   4#include "llama-vocab.h"
   5#include "llama-sampler.h"
   6
   7#include <cmath>
   8#include <algorithm>
   9#include <cstdint>
  10#include <stdexcept>
  11
// Maximum repetition count accepted in `{m}`, `{m,}` and `{m,n}` quantifiers.
// parse_sequence rejects larger counts because each repetition is expanded into
// copies of the repeated item plus extra helper rules.
#define MAX_REPETITION_THRESHOLD 2000
  13//
  14// helpers
  15//
  16
  17// NOTE: assumes valid utf8 (but checks for overrun)
  18static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
  19    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  20    uint8_t  first_byte = static_cast<uint8_t>(*src);
  21    uint8_t  highbits   = first_byte >> 4;
  22    int      len        = lookup[highbits];
  23    uint8_t  mask       = (1 << (8 - len)) - 1;
  24    uint32_t value      = first_byte & mask;
  25    const char * end    = src + len; // may overrun!
  26    const char * pos    = src + 1;
  27    for ( ; pos < end && *pos; pos++) {
  28        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
  29    }
  30    return std::make_pair(value, pos);
  31}
  32
  33static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
  34        const std::string & src,
  35        llama_partial_utf8 partial_start) {
  36    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
  37    const char          * pos      = src.c_str();
  38    std::vector<uint32_t> code_points;
  39
  40    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
  41    code_points.reserve(src.size() + 1);
  42    uint32_t value    = partial_start.value;
  43    int      n_remain = partial_start.n_remain;
  44
  45    // continue previous decode, if applicable
  46    while (*pos != 0 && n_remain > 0) {
  47        uint8_t next_byte = static_cast<uint8_t>(*pos);
  48        if ((next_byte >> 6) != 2) {
  49            // invalid sequence, abort
  50            code_points.push_back(0);
  51            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
  52        }
  53        value = (value << 6) + (next_byte & 0x3F);
  54        ++pos;
  55        --n_remain;
  56    }
  57
  58    if (partial_start.n_remain > 0 && n_remain == 0) {
  59        code_points.push_back(value);
  60    }
  61
  62    // decode any subsequent utf-8 sequences, which may end in an incomplete one
  63    while (*pos != 0) {
  64        uint8_t first_byte = static_cast<uint8_t>(*pos);
  65        uint8_t highbits   = first_byte >> 4;
  66        n_remain   = lookup[highbits] - 1;
  67
  68        if (n_remain < 0) {
  69            // invalid sequence, abort
  70            code_points.clear();
  71            code_points.push_back(0);
  72            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
  73        }
  74
  75        uint8_t mask  = (1 << (7 - n_remain)) - 1;
  76        value = first_byte & mask;
  77
  78        ++pos;
  79        while (*pos != 0 && n_remain > 0) {
  80            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
  81            ++pos;
  82            --n_remain;
  83        }
  84        if (n_remain == 0) {
  85            code_points.push_back(value);
  86        }
  87    }
  88    code_points.push_back(0);
  89
  90    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
  91}
  92
  93static bool is_digit_char(char c) {
  94    return '0' <= c && c <= '9';
  95}
  96
  97static bool is_word_char(char c) {
  98    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
  99}
 100
 101static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
 102    const char * pos   = src;
 103    const char * end   = src + size;
 104    uint32_t     value = 0;
 105    for ( ; pos < end && *pos; pos++) {
 106        value <<= 4;
 107        char c = *pos;
 108        if ('a' <= c && c <= 'f') {
 109            value += c - 'a' + 10;
 110        } else if ('A' <= c && c <= 'F') {
 111            value += c - 'A' + 10;
 112        } else if ('0' <= c && c <= '9') {
 113            value += c - '0';
 114        } else {
 115            break;
 116        }
 117    }
 118    if (pos != end) {
 119        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
 120    }
 121    return std::make_pair(value, pos);
 122}
 123
 124static const char * parse_space(const char * src, bool newline_ok) {
 125    const char * pos = src;
 126    while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
 127            (newline_ok && (*pos == '\r' || *pos == '\n'))) {
 128        if (*pos == '#') {
 129            while (*pos && *pos != '\r' && *pos != '\n') {
 130                pos++;
 131            }
 132        } else {
 133            pos++;
 134        }
 135    }
 136    return pos;
 137}
 138
 139static const char * parse_name(const char * src) {
 140    const char * pos = src;
 141    while (is_word_char(*pos)) {
 142        pos++;
 143    }
 144    if (pos == src) {
 145        throw std::runtime_error(std::string("expecting name at ") + src);
 146    }
 147    return pos;
 148}
 149
 150static const char * parse_int(const char * src) {
 151    const char * pos = src;
 152    while (is_digit_char(*pos)) {
 153        pos++;
 154    }
 155    if (pos == src) {
 156        throw std::runtime_error(std::string("expecting integer at ") + src);
 157    }
 158    return pos;
 159}
 160
 161static std::pair<uint32_t, const char *> parse_char(const char * src) {
 162    if (*src == '\\') {
 163        switch (src[1]) {
 164            case 'x': return parse_hex(src + 2, 2);
 165            case 'u': return parse_hex(src + 2, 4);
 166            case 'U': return parse_hex(src + 2, 8);
 167            case 't': return std::make_pair('\t', src + 2);
 168            case 'r': return std::make_pair('\r', src + 2);
 169            case 'n': return std::make_pair('\n', src + 2);
 170            case '\\':
 171            case '"':
 172            case '[':
 173            case ']':
 174                      return std::make_pair(src[1], src + 2);
 175            default:
 176                      throw std::runtime_error(std::string("unknown escape at ") + src);
 177        }
 178    } else if (*src) {
 179        return decode_utf8(src);
 180    }
 181    throw std::runtime_error("unexpected end of input");
 182}
 183
 184static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
 185    const char * pos = src;
 186    if (*pos != '<') {
 187        throw std::runtime_error(std::string("expecting '<' at ") + pos);
 188    }
 189    pos++;
 190
 191    // Parse <[id]>
 192    if (*pos == '[') {
 193        pos++;
 194        const char * int_end = parse_int(pos);
 195        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
 196        pos = int_end;
 197        if (*pos != ']') {
 198            throw std::runtime_error(std::string("expecting ']' at ") + pos);
 199        }
 200        pos++;
 201        if (*pos != '>') {
 202            throw std::runtime_error(std::string("expecting '>' at ") + pos);
 203        }
 204        pos++;
 205        return std::make_pair(token_id, pos);
 206    }
 207
 208    if (vocab == nullptr) {
 209        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
 210    }
 211
 212    // Parse <token> and tokenize to obtain the token id
 213    while (*pos != 0 && *pos != '>') {
 214        pos++;
 215    }
 216    if (*pos != '>') {
 217        throw std::runtime_error(std::string("expecting '>' at ") + pos);
 218    }
 219    pos++;
 220
 221    llama_token tokens[2];
 222    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
 223    if (n_tokens != 1) {
 224        // must tokenize to exactly 1 token
 225        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
 226    }
 227    return std::make_pair(tokens[0], pos);
 228}
 229
 230static void print_grammar_char(FILE * file, uint32_t c) {
 231    if (0x20 <= c && c <= 0x7f) {
 232        fprintf(file, "%c", static_cast<char>(c));
 233    } else {
 234        // cop out of encoding UTF-8
 235        fprintf(file, "<U+%04X>", c);
 236    }
 237}
 238
 239static bool is_char_element(llama_grammar_element elem) {
 240    switch (elem.type) {
 241        case LLAMA_GRETYPE_CHAR:           return true;
 242        case LLAMA_GRETYPE_CHAR_NOT:       return true;
 243        case LLAMA_GRETYPE_CHAR_ALT:       return true;
 244        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
 245        case LLAMA_GRETYPE_CHAR_ANY:       return true;
 246        default:                           return false;
 247    }
 248}
 249
 250static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
 251    for (auto elem : rule) {
 252        switch (elem.type) {
 253            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
 254            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
 255            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
 256            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
 257            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
 258            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
 259            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
 260            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
 261            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN");          break;
 262            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT");      break;
 263        }
 264        switch (elem.type) {
 265            case LLAMA_GRETYPE_END:
 266            case LLAMA_GRETYPE_ALT:
 267            case LLAMA_GRETYPE_RULE_REF:
 268                fprintf(file, "(%u) ", elem.value);
 269                break;
 270            case LLAMA_GRETYPE_CHAR:
 271            case LLAMA_GRETYPE_CHAR_NOT:
 272            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
 273            case LLAMA_GRETYPE_CHAR_ALT:
 274            case LLAMA_GRETYPE_CHAR_ANY:
 275                fprintf(file, "(\"");
 276                print_grammar_char(file, elem.value);
 277                fprintf(file, "\") ");
 278                break;
 279            case LLAMA_GRETYPE_TOKEN:
 280                fprintf(file, "<[");
 281                fprintf(file, "%u", elem.value);
 282                fprintf(file, "]> ");
 283                break;
 284            case LLAMA_GRETYPE_TOKEN_NOT:
 285                fprintf(file, "!");
 286                fprintf(file, "<[");
 287                fprintf(file, "%u", elem.value);
 288                fprintf(file, "]> ");
 289                break;
 290        }
 291    }
 292    fprintf(file, "\n");
 293}
 294
 295static void print_rule(
 296        FILE     * file,
 297        uint32_t   rule_id,
 298        const llama_grammar_rule & rule,
 299        const std::map<uint32_t, std::string> & symbol_id_names) {
 300    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
 301        throw std::runtime_error(
 302            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
 303    }
 304    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
 305    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
 306        llama_grammar_element elem = rule[i];
 307        switch (elem.type) {
 308            case LLAMA_GRETYPE_END:
 309                throw std::runtime_error(
 310                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
 311                    std::to_string(i));
 312            case LLAMA_GRETYPE_ALT:
 313                fprintf(file, "| ");
 314                break;
 315            case LLAMA_GRETYPE_RULE_REF:
 316                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
 317                break;
 318            case LLAMA_GRETYPE_CHAR:
 319                fprintf(file, "[");
 320                print_grammar_char(file, elem.value);
 321                break;
 322            case LLAMA_GRETYPE_CHAR_NOT:
 323                fprintf(file, "[^");
 324                print_grammar_char(file, elem.value);
 325                break;
 326            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
 327                if (i == 0 || !is_char_element(rule[i - 1])) {
 328                    throw std::runtime_error(
 329                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
 330                        std::to_string(rule_id) + "," + std::to_string(i));
 331                }
 332                fprintf(file, "-");
 333                print_grammar_char(file, elem.value);
 334                break;
 335            case LLAMA_GRETYPE_CHAR_ALT:
 336                if (i == 0 || !is_char_element(rule[i - 1])) {
 337                    throw std::runtime_error(
 338                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
 339                        std::to_string(rule_id) + "," + std::to_string(i));
 340                }
 341                print_grammar_char(file, elem.value);
 342                break;
 343            case LLAMA_GRETYPE_CHAR_ANY:
 344                fprintf(file, ".");
 345                break;
 346            case LLAMA_GRETYPE_TOKEN:
 347                fprintf(file, "<[");
 348                fprintf(file, "%u", elem.value);
 349                fprintf(file, "]> ");
 350                break;
 351            case LLAMA_GRETYPE_TOKEN_NOT:
 352                fprintf(file, "!");
 353                fprintf(file, "<[");
 354                fprintf(file, "%u", elem.value);
 355                fprintf(file, "]> ");
 356                break;
 357        }
 358        if (is_char_element(elem)) {
 359            switch (rule[i + 1].type) {
 360                case LLAMA_GRETYPE_CHAR_ALT:
 361                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
 362                case LLAMA_GRETYPE_CHAR_ANY:
 363                    break;
 364                default:
 365                    fprintf(file, "] ");
 366            }
 367        }
 368    }
 369    fprintf(file, "\n");
 370}
 371
 372//
 373// Regex utilities
 374//
 375
 376size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
 377    auto find_start_pos = [](const std::smatch & match) {
 378        // get from the first matched capturing group to the end of the string
 379        size_t start = std::string::npos;
 380        for (auto i = 1u; i < match.size(); i++) {
 381            if (match.length(i) > 0) {
 382                start = match.position(i);
 383                break;
 384            }
 385        }
 386        if (start == std::string::npos) {
 387            start = match.position(0);
 388        }
 389        return start;
 390    };
 391
 392    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
 393        // match against the entire input
 394        std::smatch match;
 395        if (std::regex_match(input, match, regex)) {
 396            return find_start_pos(match);
 397        }
 398    }
 399
 400    // search anywhere
 401    std::smatch match;
 402    if (std::regex_search(input, match, regex)) {
 403        return find_start_pos(match);
 404    }
 405
 406    return std::string::npos;
 407}
 408
 409
 410//
 411// implementation
 412//
 413
 414uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
 415    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
 416    auto result = symbol_ids.emplace(std::string(src, len), next_id);
 417    return result.first->second;
 418}
 419
 420uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
 421    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
 422    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
 423    return next_id;
 424}
 425
 426void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
 427    if (rules.size() <= rule_id) {
 428        rules.resize(rule_id + 1);
 429    }
 430    rules[rule_id] = rule;
 431}
 432
 433const char * llama_grammar_parser::parse_alternates(
 434        const char        * src,
 435        const std::string & rule_name,
 436        uint32_t            rule_id,
 437        bool                is_nested) {
 438    llama_grammar_rule rule;
 439    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
 440    while (*pos == '|') {
 441        rule.push_back({LLAMA_GRETYPE_ALT, 0});
 442        pos = parse_space(pos + 1, true);
 443        pos = parse_sequence(pos, rule_name, rule, is_nested);
 444    }
 445    rule.push_back({LLAMA_GRETYPE_END, 0});
 446    add_rule(rule_id, rule);
 447    return pos;
 448}
 449
 450const char * llama_grammar_parser::parse_sequence(
 451        const char         * src,
 452        const std::string  & rule_name,
 453        llama_grammar_rule & rule,
 454        bool               is_nested) {
 455    size_t last_sym_start = rule.size();
 456    const char * pos = src;
 457
 458    // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
 459    // (though it's technically the same as -1 now)
 460    auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
 461        bool no_max = max_times == UINT64_MAX;
 462        if (last_sym_start == rule.size()) {
 463            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
 464        }
 465
 466        // apply transformation to previous symbol (last_sym_start to end) according to
 467        // the following rewrite rules:
 468        // S{m,n} --> S S S (m times) S'(n-m)
 469        //            S'(x)   ::= S S'(x-1) |
 470        //            (... n-m definitions of these S' rules ...)
 471        //            S'(1)   ::= S |
 472        // S{m,} -->  S S S (m times) S'
 473        //            S'     ::= S S' |
 474        // S*     --> S{0,}
 475        //        --> S'     ::= S S' |
 476        // S+     --> S{1,}
 477        //        --> S S'
 478        //            S'     ::= S S' |
 479        // S?     --> S{0,1}
 480        //        --> S'
 481        //            S'     ::= S |
 482
 483        llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
 484        if (min_times == 0) {
 485            rule.resize(last_sym_start);
 486        } else {
 487            // Repeat the previous elements (min_times - 1) times
 488            for (uint64_t i = 1; i < min_times; i++) {
 489                rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
 490            }
 491        }
 492
 493        uint32_t last_rec_rule_id = 0;
 494        auto n_opt = no_max ? 1 : max_times - min_times;
 495
 496        llama_grammar_rule rec_rule(prev_rule);
 497        for (uint64_t i = 0; i < n_opt; i++) {
 498            rec_rule.resize(prev_rule.size());
 499            uint32_t rec_rule_id = generate_symbol_id( rule_name);
 500            if (i > 0 || no_max) {
 501                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
 502            }
 503            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
 504            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
 505            add_rule( rec_rule_id, rec_rule);
 506            last_rec_rule_id = rec_rule_id;
 507        }
 508        if (n_opt > 0) {
 509            rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
 510        }
 511    };
 512
 513    while (*pos) {
 514        if (*pos == '"') { // literal string
 515            pos++;
 516            last_sym_start = rule.size();
 517            while (*pos != '"') {
 518                if (!*pos) {
 519                    throw std::runtime_error("unexpected end of input");
 520                }
 521                auto char_pair = parse_char(pos);
 522                     pos       = char_pair.second;
 523                rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
 524            }
 525            pos = parse_space(pos + 1, is_nested);
 526        } else if (*pos == '[') { // char range(s)
 527            pos++;
 528            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
 529            if (*pos == '^') {
 530                pos++;
 531                start_type = LLAMA_GRETYPE_CHAR_NOT;
 532            }
 533            last_sym_start = rule.size();
 534            while (*pos != ']') {
 535                if (!*pos) {
 536                    throw std::runtime_error("unexpected end of input");
 537                }
 538                auto char_pair = parse_char(pos);
 539                     pos       = char_pair.second;
 540                enum llama_gretype type = last_sym_start < rule.size()
 541                    ? LLAMA_GRETYPE_CHAR_ALT
 542                    : start_type;
 543
 544                rule.push_back({type, char_pair.first});
 545                if (pos[0] == '-' && pos[1] != ']') {
 546                    if (!pos[1]) {
 547                        throw std::runtime_error("unexpected end of input");
 548                    }
 549                    auto endchar_pair = parse_char(pos + 1);
 550                         pos          = endchar_pair.second;
 551                    rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
 552                }
 553            }
 554            pos = parse_space(pos + 1, is_nested);
 555        } else if (*pos == '<' || *pos == '!') { // token
 556            auto type = LLAMA_GRETYPE_TOKEN;
 557            if (*pos == '!') { // token inverse
 558                type = LLAMA_GRETYPE_TOKEN_NOT;
 559                pos++;
 560            }
 561            auto token_pair = parse_token(vocab, pos);
 562            const char * token_end  = token_pair.second;
 563            last_sym_start = rule.size();
 564            rule.push_back({type, token_pair.first});
 565            pos = parse_space(token_end, is_nested);
 566        } else if (is_word_char(*pos)) { // rule reference
 567            const char * name_end    = parse_name(pos);
 568            uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
 569            pos = parse_space(name_end, is_nested);
 570            last_sym_start = rule.size();
 571            rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
 572        } else if (*pos == '(') { // grouping
 573            // parse nested alternates into synthesized rule
 574            pos = parse_space(pos + 1, true);
 575            uint32_t sub_rule_id = generate_symbol_id(rule_name);
 576            pos = parse_alternates(pos, rule_name, sub_rule_id, true);
 577            last_sym_start = rule.size();
 578            // output reference to synthesized rule
 579            rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
 580            if (*pos != ')') {
 581                throw std::runtime_error(std::string("expecting ')' at ") + pos);
 582            }
 583            pos = parse_space(pos + 1, is_nested);
 584        } else if (*pos == '.') { // any char
 585            last_sym_start = rule.size();
 586            rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
 587            pos = parse_space(pos + 1, is_nested);
 588        } else if (*pos == '*') {
 589            pos = parse_space(pos + 1, is_nested);
 590            handle_repetitions(0, -1);
 591        } else if (*pos == '+') {
 592            pos = parse_space(pos + 1, is_nested);
 593            handle_repetitions(1, -1);
 594        } else if (*pos == '?') {
 595            pos = parse_space(pos + 1, is_nested);
 596            handle_repetitions(0, 1);
 597        } else if (*pos == '{') {
 598            pos = parse_space(pos + 1, is_nested);
 599
 600            if (!is_digit_char(*pos)) {
 601                throw std::runtime_error(std::string("expecting an int at ") + pos);
 602            }
 603            const char * int_end = parse_int(pos);
 604            uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
 605            pos = parse_space(int_end, is_nested);
 606
 607            uint64_t max_times = UINT64_MAX; // default: no max limit
 608
 609            if (*pos == '}') {
 610                max_times = min_times;
 611                pos = parse_space(pos + 1, is_nested);
 612            } else if (*pos == ',') {
 613                pos = parse_space(pos + 1, is_nested);
 614
 615                if (is_digit_char(*pos)) {
 616                    const char * int_end = parse_int(pos);
 617                    max_times = std::stoul(std::string(pos, int_end - pos));
 618                    pos = parse_space(int_end, is_nested);
 619                }
 620
 621                if (*pos != '}') {
 622                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
 623                }
 624                pos = parse_space(pos + 1, is_nested);
 625            } else {
 626                throw std::runtime_error(std::string("expecting ',' at ") + pos);
 627            }
 628            bool has_max = max_times != UINT64_MAX;
 629            if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
 630                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
 631            }
 632            handle_repetitions(min_times, max_times);
 633        } else {
 634            break;
 635        }
 636    }
 637    return pos;
 638}
 639
 640const char * llama_grammar_parser::parse_rule(const char * src) {
 641    const char * name_end = parse_name(src);
 642    const char * pos      = parse_space(name_end, false);
 643    size_t       name_len = name_end - src;
 644    uint32_t     rule_id  = get_symbol_id(src, name_len);
 645    const std::string name(src, name_len);
 646
 647    if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
 648        throw std::runtime_error(std::string("expecting ::= at ") + pos);
 649    }
 650    pos = parse_space(pos + 3, true);
 651
 652    pos = parse_alternates(pos, name, rule_id, false);
 653
 654    if (*pos == '\r') {
 655        pos += pos[1] == '\n' ? 2 : 1;
 656    } else if (*pos == '\n') {
 657        pos++;
 658    } else if (*pos) {
 659        throw std::runtime_error(std::string("expecting newline or end at ") + pos);
 660    }
 661    return parse_space(pos, true);
 662}
 663
 664bool llama_grammar_parser::parse(const char * src) {
 665    try {
 666        const char * pos = parse_space(src, true);
 667        while (*pos) {
 668            pos = parse_rule(pos);
 669        }
 670        // Validate the state to ensure that all rules are defined
 671        for (const auto & rule : rules) {
 672            if (rule.empty()) {
 673                throw std::runtime_error("Undefined rule");
 674            }
 675            for (const auto & elem : rule) {
 676                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
 677                    // Ensure that the rule at that location exists
 678                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
 679                        // Get the name of the rule that is missing
 680                        for (const auto & kv : symbol_ids) {
 681                            if (kv.second == elem.value) {
 682                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
 683                            }
 684                        }
 685                    }
 686                }
 687            }
 688        }
 689    } catch (const std::exception & err) {
 690        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
 691        rules.clear();
 692        return false;
 693    }
 694
 695    return true;
 696}
 697
 698void llama_grammar_parser::print(FILE * file) {
 699    try {
 700        std::map<uint32_t, std::string> symbol_id_names;
 701        for (const auto & kv : symbol_ids) {
 702            symbol_id_names[kv.second] = kv.first;
 703        }
 704        for (size_t i = 0, end = rules.size(); i < end; i++) {
 705            // fprintf(file, "%zu: ", i);
 706            // print_rule_binary(file, rules[i]);
 707            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
 708            // fprintf(file, "\n");
 709        }
 710    } catch (const std::exception & err) {
 711        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
 712    }
 713}
 714
 715llama_grammar_stack llama_grammar_parser::c_rules() const {
 716    llama_grammar_stack ret;
 717    ret.reserve(rules.size());
 718    for (const auto & rule : rules) {
 719        ret.push_back(rule.data());
 720    }
 721    return ret;
 722}
 723
 724// returns true iff pos points to the end of one of the definitions of a rule
 725static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
 726    switch (pos->type) {
 727        case LLAMA_GRETYPE_END: return true;  // NOLINT
 728        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
 729        default:                return false;
 730    }
 731}
 732
// Checks whether the code point `chr` satisfies the character class starting at `pos`.
// Preconditions and layout:
//   - `pos` must point at a CHAR, CHAR_NOT or CHAR_ANY element (asserted below).
//   - The class continues over the following CHAR_ALT elements.
//   - Any element followed by a CHAR_RNG_UPPER element forms an inclusive range with it.
// Returns {matched, pointer to the first element after the class}.
// For a CHAR_NOT class the result is inverted: it matches iff `chr` is in none of
// the listed chars/ranges.
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
        const llama_grammar_element * pos,
        const uint32_t                chr) {
    bool found            = false;
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;

    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT

    // walk every item of the class (even after a hit) so the returned pointer lands past it;
    // peeking at pos[1] is safe because rules are terminated by an END element
    do {
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
            // inclusive range, e.g. [a-z]
            found = found || (pos->value <= chr && chr <= pos[1].value);
            pos += 2;
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
            // Any character matches "."
            found = true;
            pos += 1;
        } else {
            // exact char match, e.g. [a] or "a"
            found = found || pos->value == chr;
            pos += 1;
        }
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);

    // positive class: match iff found; negated class (CHAR_NOT): match iff not found
    return std::make_pair(found == is_positive_char, pos);
}
 761
// Checks whether some completion of the incomplete UTF-8 sequence `partial_utf8`
// could satisfy the character class starting at `pos`.
// The class layout is the same as in llama_grammar_match_char: `pos` must point at
// CHAR, CHAR_NOT or CHAR_ANY (asserted), followed by CHAR_ALT / CHAR_RNG_UPPER items.
// `partial_utf8.value` holds the payload bits decoded so far and
// `partial_utf8.n_remain` the number of continuation bytes still missing
// (negative for an invalid sequence), as produced by decode_utf8.
// Returns false for an invalid or overlong partial sequence.
// NOTE(review): for a negated class (CHAR_NOT) this returns false as soon as any
// excluded char/range overlaps [low, high], even if other completions in that
// interval would be allowed — i.e. it is conservative. Confirm this is intended.
static bool llama_grammar_match_partial_char(
        const llama_grammar_element * pos,
        const llama_partial_utf8      partial_utf8) {
    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);

    uint32_t partial_value = partial_utf8.value;
    int      n_remain      = partial_utf8.n_remain;

    // invalid sequence or 7-bit char split across 2 bytes (overlong)
    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
        return false;
    }

    // range of possible code points this partial UTF-8 sequence could complete to:
    // each missing continuation byte contributes 6 more payload bits
    uint32_t low  = partial_value << (n_remain * 6);
    uint32_t high = low | ((1 << (n_remain * 6)) - 1);

    // no payload bits set yet: the smallest values would be overlong encodings, so
    // start at the first code point that needs that many bytes
    // (U+0800 when 2 bytes remain, U+10000 when 3 remain)
    if (low == 0) {
        if (n_remain == 2) {
            low = 1 << 11;
        } else if (n_remain == 3) {
            low = 1 << 16;
        }
    }

    // any overlap between [low, high] and a listed char/range decides the result
    do {
        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
            // inclusive range, e.g. [a-z]
            if (pos->value <= high && low <= pos[1].value) {
                return is_positive_char;
            }
            pos += 2;
        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
            // Any character matches "."
            return true;
        } else {
            // exact char match, e.g. [a] or "a"
            if (low <= pos->value && pos->value <= high) {
                return is_positive_char;
            }
            pos += 1;
        }
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);

    // no overlap: a positive class cannot match, a negated class can
    return !is_positive_char;
}
 812
 813// returns true iff token matches the rule at pos (regular or inverse)
 814// asserts that pos is pointing to a token element
 815static bool llama_grammar_match_token(
 816    const llama_grammar_element * pos,
 817    const llama_token             token) {
 818    GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
 819    if (pos->type == LLAMA_GRETYPE_TOKEN) {
 820        return pos->value == static_cast<uint32_t>(token);
 821    }
 822    if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
 823        return pos->value != static_cast<uint32_t>(token);
 824    }
 825    return false;
 826}
 827
 828// transforms a grammar pushdown stack into N possible stacks, all ending
 829// at a character range (terminal element)
 830static void llama_grammar_advance_stack(
 831        const llama_grammar_rules  & rules,
 832        const llama_grammar_stack  & stack,
 833              llama_grammar_stacks & new_stacks) {
 834    if (stack.empty()) {
 835        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
 836            new_stacks.emplace_back(stack);
 837        }
 838        return;
 839    }
 840
 841    const llama_grammar_element * pos = stack.back();
 842
 843    switch (pos->type) {
 844        case LLAMA_GRETYPE_RULE_REF: {
 845            const size_t                  rule_id = static_cast<size_t>(pos->value);
 846            const llama_grammar_element * subpos  = rules[rule_id].data();
 847            do {
 848                // init new stack without the top (pos)
 849                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
 850                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
 851                    // if this rule ref is followed by another element, add that to stack
 852                    new_stack.push_back(pos + 1);
 853                }
 854                if (!llama_grammar_is_end_of_sequence(subpos)) {
 855                    // if alternate is nonempty, add to stack
 856                    new_stack.push_back(subpos);
 857                }
 858                llama_grammar_advance_stack(rules, new_stack, new_stacks);
 859                while (!llama_grammar_is_end_of_sequence(subpos)) {
 860                    // scan to end of alternate def
 861                    subpos++;
 862                }
 863                if (subpos->type == LLAMA_GRETYPE_ALT) {
 864                    // there's another alternate def of this rule to process
 865                    subpos++;
 866                } else {
 867                    break;
 868                }
 869            } while (true);
 870            break;
 871        }
 872        case LLAMA_GRETYPE_CHAR:
 873        case LLAMA_GRETYPE_CHAR_NOT:
 874        case LLAMA_GRETYPE_CHAR_ANY:
 875        case LLAMA_GRETYPE_TOKEN:
 876        case LLAMA_GRETYPE_TOKEN_NOT:
 877            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
 878                // only add the stack if it's not a duplicate of one we already have
 879                new_stacks.emplace_back(stack);
 880            }
 881            break;
 882        default:
 883            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
 884            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
 885            // those
 886            GGML_ABORT("fatal error");
 887    }
 888}
 889
 890static llama_grammar_candidates llama_grammar_reject_candidates(
 891        const llama_grammar_rules      & rules,
 892        const llama_grammar_stacks     & stacks,
 893        const llama_grammar_candidates & candidates) {
 894    GGML_ASSERT(!stacks.empty()); // REVIEW
 895
 896    if (candidates.empty()) {
 897        return {};
 898    }
 899
 900    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
 901
 902    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
 903        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
 904    }
 905
 906    return rejects;
 907}
 908
 909static bool llama_grammar_detect_left_recursion(
 910        const llama_grammar_rules & rules,
 911        size_t rule_index,
 912        std::vector<bool> * rules_visited,
 913        std::vector<bool> * rules_in_progress,
 914        std::vector<bool> * rules_may_be_empty) {
 915    if ((*rules_in_progress)[rule_index]) {
 916        return true;
 917    }
 918
 919    (*rules_in_progress)[rule_index] = true;
 920
 921    const llama_grammar_rule & rule = rules[rule_index];
 922
 923    // First check if the rule might produce the empty string. This could be done combined with the second
 924    // step but it's more readable as two steps.
 925    bool at_rule_start = true;
 926    for (size_t i = 0; i < rule.size(); i++) {
 927        if (llama_grammar_is_end_of_sequence(&rule[i])) {
 928            if (at_rule_start) {
 929                (*rules_may_be_empty)[rule_index] = true;
 930                break;
 931            }
 932            at_rule_start = true;
 933        } else {
 934            at_rule_start = false;
 935        }
 936    }
 937
 938    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
 939    // be empty)
 940    bool recurse_into_nonterminal = true;
 941    for (size_t i = 0; i < rule.size(); i++) {
 942        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
 943            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
 944                return true;
 945            }
 946            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
 947                recurse_into_nonterminal = false;
 948            }
 949        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
 950            recurse_into_nonterminal = true;
 951        } else {
 952            recurse_into_nonterminal = false;
 953        }
 954    }
 955
 956    (*rules_in_progress)[rule_index] = false;
 957    (*rules_visited)[rule_index] = true;
 958
 959    return false;
 960}
 961
 962const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
 963    return grammar->rules;
 964}
 965
 966llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
 967    return grammar->stacks;
 968}
 969
 970static void llama_grammar_accept_chr(
 971        struct llama_grammar       & grammar,
 972        const llama_grammar_stack  & stack,
 973              uint32_t               chr,
 974              llama_grammar_stacks & new_stacks) {
 975    if (stack.empty()) {
 976        return;
 977    }
 978
 979    const llama_grammar_element * pos = stack.back();
 980
 981    // ignore if this turns into a token
 982    if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
 983        return;
 984    }
 985
 986    auto match = llama_grammar_match_char(pos, chr);
 987    if (match.first) {
 988        llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
 989        if (!llama_grammar_is_end_of_sequence(match.second)) {
 990            new_stack.push_back(match.second);
 991        }
 992        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
 993    }
 994}
 995
 996void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
 997    llama_grammar_stacks stacks_new;
 998    stacks_new.reserve(grammar->stacks.size());
 999
1000    for (const auto & stack : grammar->stacks) {
1001        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
1002    }
1003
1004    grammar->stacks = std::move(stacks_new);
1005}
1006
1007llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
1008        const llama_grammar_rules      & rules,
1009        const llama_grammar_stack      & stack,
1010        const llama_grammar_candidates & candidates) {
1011
1012    llama_grammar_candidates rejects;
1013    rejects.reserve(candidates.size());
1014
1015    if (stack.empty()) {
1016        for (const auto & tok : candidates) {
1017            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
1018                rejects.push_back(tok);
1019            }
1020        }
1021        return rejects;
1022    }
1023
1024    const llama_grammar_element * stack_pos = stack.back();
1025
1026    // if the top of the stack is a token rule, then we only need to check the token id
1027    if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1028        for (const auto & tok : candidates) {
1029            if (*tok.code_points == 0) {
1030                // reached the end of a token consumed by char rules, reject iff it ended
1031                // in a partial response
1032                if (tok.partial_utf8.n_remain != 0) {
1033                    rejects.push_back(tok);
1034                }
1035            } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
1036                rejects.push_back(tok);
1037            }
1038        }
1039        return rejects;
1040    }
1041
1042    llama_grammar_candidates next_candidates;
1043    next_candidates.reserve(candidates.size());
1044
1045    for (const auto & tok : candidates) {
1046        if (*tok.code_points == 0) {
1047            // reached end of full codepoints in token, reject iff it ended in a partial sequence
1048            // that cannot satisfy this position in grammar
1049            if (tok.partial_utf8.n_remain != 0 &&
1050                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
1051                rejects.push_back(tok);
1052            }
1053        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
1054            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
1055        } else {
1056            rejects.push_back(tok);
1057        }
1058    }
1059
1060    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
1061
1062    // update top of stack to next element, if any
1063    llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
1064    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
1065        stack_after.push_back(stack_pos_after);
1066    }
1067    llama_grammar_stacks next_stacks;
1068    llama_grammar_advance_stack(rules, stack_after, next_stacks);
1069
1070    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
1071    for (const auto & tok : next_rejects) {
1072        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
1073    }
1074
1075    return rejects;
1076}
1077
1078////////////////////
1079
1080struct llama_grammar * llama_grammar_init_impl(
1081        const struct llama_vocab * vocab,
1082        const llama_grammar_element ** rules,
1083        size_t n_rules,
1084        size_t start_rule_index) {
1085    const llama_grammar_element * pos;
1086
1087    // copy rule definitions into vectors
1088    llama_grammar_rules vec_rules(n_rules);
1089    for (size_t i = 0; i < n_rules; i++) {
1090        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1091            vec_rules[i].push_back(*pos);
1092        }
1093        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1094    }
1095
1096    // Check for left recursion
1097    std::vector<bool> rules_visited(n_rules);
1098    std::vector<bool> rules_in_progress(n_rules);
1099    std::vector<bool> rules_may_be_empty(n_rules);
1100    for (size_t i = 0; i < n_rules; i++) {
1101        if (rules_visited[i]) {
1102            continue;
1103        }
1104        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1105            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1106            return nullptr;
1107        }
1108    }
1109
1110    // loop over alternates of start rule to build initial stacks
1111    llama_grammar_stacks stacks;
1112    pos = vec_rules[start_rule_index].data();
1113    do {
1114        llama_grammar_stack stack;
1115        if (!llama_grammar_is_end_of_sequence(pos)) {
1116            // if alternate is nonempty, add to stack
1117            stack.push_back(pos);
1118        }
1119        llama_grammar_advance_stack(vec_rules, stack, stacks);
1120        while (!llama_grammar_is_end_of_sequence(pos)) {
1121            // scan to end of alternate def
1122            pos++;
1123        }
1124        if (pos->type == LLAMA_GRETYPE_ALT) {
1125            // there's another alternate def of this rule to process
1126            pos++;
1127        } else {
1128            break;
1129        }
1130    } while (true);
1131
1132    // Important: vec_rules has to be moved here, not copied, because stacks contains
1133    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1134    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1135    return new llama_grammar {
1136        vocab,
1137        std::move(vec_rules),
1138        std::move(stacks),
1139        /* .partial_utf8 = */             {},
1140        /* .lazy = */                     false,
1141        /* .awaiting_trigger = */         false,
1142        /* .trigger_buffer = */           "",
1143        /* .trigger_buffer_positions = */ {},
1144        /* .trigger_tokens = */           {},
1145        /* .trigger_patterns = */         {},
1146    };
1147}
1148
1149struct llama_grammar * llama_grammar_init_impl(
1150        const struct llama_vocab * vocab,
1151                      const char * grammar_str,
1152                      const char * grammar_root,
1153                              bool lazy,
1154                     const char ** trigger_patterns,
1155                            size_t num_trigger_patterns,
1156               const llama_token * trigger_tokens,
1157                            size_t num_trigger_tokens) {
1158    llama_grammar_parser parser(vocab);
1159
1160    // if there is a grammar, parse it
1161    // rules will be empty (default) if there are parse errors
1162    if (!parser.parse(grammar_str) || parser.rules.empty()) {
1163        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
1164        return nullptr;
1165    }
1166
1167    // Ensure that there is a "root" node.
1168    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
1169        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
1170        return nullptr;
1171    }
1172
1173    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
1174
1175    const size_t n_rules = grammar_rules.size();
1176    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
1177
1178    const llama_grammar_element * pos;
1179
1180    // copy rule definitions into vectors
1181    llama_grammar_rules vec_rules(n_rules);
1182    for (size_t i = 0; i < n_rules; i++) {
1183        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1184            vec_rules[i].push_back(*pos);
1185        }
1186        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1187    }
1188
1189    // Check for left recursion
1190    std::vector<bool> rules_visited(n_rules);
1191    std::vector<bool> rules_in_progress(n_rules);
1192    std::vector<bool> rules_may_be_empty(n_rules);
1193    for (size_t i = 0; i < n_rules; i++) {
1194        if (rules_visited[i]) {
1195            continue;
1196        }
1197        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1198            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1199            return nullptr;
1200        }
1201    }
1202
1203    // loop over alternates of start rule to build initial stacks
1204    llama_grammar_stacks stacks;
1205    pos = vec_rules[start_rule_index].data();
1206    do {
1207        llama_grammar_stack stack;
1208        if (!llama_grammar_is_end_of_sequence(pos)) {
1209            // if alternate is nonempty, add to stack
1210            stack.push_back(pos);
1211        }
1212        llama_grammar_advance_stack(vec_rules, stack, stacks);
1213        while (!llama_grammar_is_end_of_sequence(pos)) {
1214            // scan to end of alternate def
1215            pos++;
1216        }
1217        if (pos->type == LLAMA_GRETYPE_ALT) {
1218            // there's another alternate def of this rule to process
1219            pos++;
1220        } else {
1221            break;
1222        }
1223    } while (true);
1224
1225    std::vector<llama_token>    vec_trigger_tokens;
1226    std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
1227    for (size_t i = 0; i < num_trigger_tokens; i++) {
1228        GGML_ASSERT(trigger_tokens != nullptr);
1229        vec_trigger_tokens.push_back(trigger_tokens[i]);
1230    }
1231    for (size_t i = 0; i < num_trigger_patterns; i++) {
1232        GGML_ASSERT(trigger_patterns != nullptr);
1233        auto & trigger = vec_trigger_patterns.emplace_back();
1234        trigger.pattern = trigger_patterns[i];
1235        trigger.regex = std::regex(trigger.pattern);
1236    }
1237
1238    // Important: vec_rules has to be moved here, not copied, because stacks contains
1239    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1240    // then the pointers would be invalidated when the local vec_rules goes out of scope.
1241    return new llama_grammar {
1242        vocab,
1243        std::move(vec_rules),
1244        std::move(stacks),
1245        /* .partial_utf8 = */             {},
1246        /* .lazy = */                     lazy,
1247        /* .awaiting_trigger = */         lazy,
1248        /* .trigger_buffer = */           "",
1249        /* .trigger_buffer_positions = */ {},
1250        std::move(vec_trigger_tokens),
1251        std::move(vec_trigger_patterns),
1252    };
1253}
1254
// Releases a grammar created by llama_grammar_init_impl / llama_grammar_clone_impl.
// Passing nullptr is allowed: deleting a null pointer is a no-op.
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    delete grammar;
}
1262
1263struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1264    auto * result = new llama_grammar {
1265        grammar.vocab,
1266        grammar.rules,
1267        grammar.stacks,
1268        grammar.partial_utf8,
1269        grammar.lazy,
1270        grammar.awaiting_trigger,
1271        grammar.trigger_buffer,
1272        grammar.trigger_buffer_positions,
1273        grammar.trigger_tokens,
1274        grammar.trigger_patterns,
1275    };
1276
1277    // redirect elements in stacks to point to new rules
1278    for (size_t is = 0; is < result->stacks.size(); is++) {
1279        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
1280            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1281                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1282                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1283                        result->stacks[is][ie] =  &result->rules[ir0][ir1];
1284                    }
1285                }
1286            }
1287        }
1288    }
1289
1290    return result;
1291}
1292
1293void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1294    GGML_ASSERT(grammar.vocab != nullptr);
1295
1296    if (grammar.awaiting_trigger) {
1297        return;
1298    }
1299
1300    bool allow_eog = false;
1301    for (const auto & stack : grammar.stacks) {
1302        if (stack.empty()) {
1303            allow_eog = true;
1304            break;
1305        }
1306    }
1307
1308    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
1309    candidates_decoded.reserve(cur_p->size);
1310
1311    llama_grammar_candidates candidates_grammar;
1312    candidates_grammar.reserve(cur_p->size);
1313
1314    for (size_t i = 0; i < cur_p->size; ++i) {
1315        const llama_token id      = cur_p->data[i].id;
1316        const std::string & piece = grammar.vocab->token_to_piece(id);
1317
1318        if (grammar.vocab->is_eog(id)) {
1319            if (!allow_eog) {
1320                cur_p->data[i].logit = -INFINITY;
1321            }
1322        } else if (piece.empty() || piece[0] == 0) {
1323            cur_p->data[i].logit = -INFINITY;
1324        } else {
1325            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1326            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
1327        }
1328    }
1329
1330    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
1331    for (const auto & reject : rejects) {
1332        cur_p->data[reject.index].logit = -INFINITY;
1333    }
1334}
1335
1336void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1337    GGML_ASSERT(grammar.vocab != nullptr);
1338
1339    const auto & piece = grammar.vocab->token_to_piece(token);
1340
1341    if (grammar.awaiting_trigger) {
1342        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
1343            grammar.awaiting_trigger = false;
1344            grammar.trigger_buffer.clear();
1345            llama_grammar_accept_token(grammar, token, piece);
1346            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1347            return;
1348        } else {
1349            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
1350            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
1351            grammar.trigger_buffer += piece;
1352
1353            for (const auto & trigger_pattern : grammar.trigger_patterns) {
1354                auto start = trigger_pattern.find(grammar.trigger_buffer);
1355                if (start != std::string::npos) {
1356                    grammar.awaiting_trigger = false;
1357
1358                    // replay tokens that overlap with [start, end)
1359                    for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
1360                        auto [tok_start, tok_end] = tok_pos;
1361                        if (tok_end <= start) {
1362                            continue;
1363                        }
1364
1365                        size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
1366                        size_t piece_len = tok_end - piece_start;
1367                        auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
1368                        llama_grammar_accept_token(grammar, tok, tok_piece);
1369                    }
1370
1371                    auto constrained_str = grammar.trigger_buffer.substr(start);
1372                    grammar.trigger_buffer.clear();
1373                    grammar.trigger_buffer_positions.clear();
1374                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1375                    return;
1376                }
1377            }
1378            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
1379            return;
1380        }
1381    }
1382
1383    if (grammar.vocab->is_eog(token)) {
1384        for (const auto & stack : grammar.stacks) {
1385            if (stack.empty()) {
1386                return;
1387            }
1388        }
1389        GGML_ABORT("fatal error");
1390    }
1391
1392    llama_grammar_accept_token(grammar, token, piece);
1393}
1394
1395void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
1396    // Note terminating 0 in decoded string
1397    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1398    const auto & code_points = decoded.first;
1399
1400    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1401        llama_grammar_accept(&grammar, *it);
1402    }
1403
1404    grammar.partial_utf8 = decoded.second;
1405    if (grammar.stacks.empty()) {
1406        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
1407    }
1408}
1409
1410void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
1411    // Note terminating 0 in decoded string
1412    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
1413    const auto & code_points = decoded.first;
1414
1415    llama_grammar_stacks stacks_new;
1416    stacks_new.reserve(grammar.stacks.size());
1417
1418    for (const auto & stack : grammar.stacks) {
1419        if (stack.empty()) {
1420            continue;
1421        }
1422
1423        const llama_grammar_element * pos = stack.back();
1424
1425        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1426            if (llama_grammar_match_token(pos, token)) {
1427                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1428                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
1429                    new_stack.push_back(pos + 1);
1430                }
1431                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
1432            }
1433        } else {
1434            llama_grammar_stacks current_stacks = {stack};
1435
1436            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1437                llama_grammar_stacks next_stacks;
1438
1439                for (const auto & cur_stack : current_stacks) {
1440                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
1441                }
1442
1443                current_stacks = std::move(next_stacks);
1444                if (current_stacks.empty()) {
1445                    break;
1446                }
1447            }
1448
1449            for (auto & surviving_stack : current_stacks) {
1450                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
1451                    stacks_new.emplace_back(surviving_stack);
1452                }
1453            }
1454        }
1455    }
1456
1457    grammar.stacks = std::move(stacks_new);
1458    grammar.partial_utf8 = decoded.second;
1459
1460    if (grammar.stacks.empty()) {
1461        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
1462    }
1463}
1464