summaryrefslogtreecommitdiff
path: root/llama.cpp/src/llama-grammar.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/src/llama-grammar.cpp')
-rw-r--r--llama.cpp/src/llama-grammar.cpp1464
1 files changed, 1464 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-grammar.cpp b/llama.cpp/src/llama-grammar.cpp
new file mode 100644
index 0000000..2d55070
--- /dev/null
+++ b/llama.cpp/src/llama-grammar.cpp
@@ -0,0 +1,1464 @@
1#include "llama-grammar.h"
2
3#include "llama-impl.h"
4#include "llama-vocab.h"
5#include "llama-sampler.h"
6
7#include <cmath>
8#include <algorithm>
9#include <cstdint>
10#include <stdexcept>
11
12#define MAX_REPETITION_THRESHOLD 2000
13//
14// helpers
15//
16
// Decodes the single UTF-8 sequence that begins at `src`.
// NOTE: the input is assumed to be valid UTF-8; the only safeguard is that decoding stops at a
// NUL byte, so a truncated trailing sequence never reads past the terminator.
// Returns the decoded code point together with a pointer to the first byte after the bytes consumed.
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // total sequence length, indexed by the high nibble of the lead byte
    static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead         = static_cast<uint8_t>(*src);
    const uint8_t nibble       = lead >> 4;
    const int     seq_len      = seq_len_by_nibble[nibble];
    const uint8_t payload_mask = (1 << (8 - seq_len)) - 1;

    uint32_t code_point = lead & payload_mask;

    const char * cursor = src + 1;
    const char * limit  = src + seq_len; // may lie beyond the terminator!
    while (cursor < limit && *cursor) {
        // each continuation byte contributes its low 6 bits
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cursor) & 0x3F);
        ++cursor;
    }
    return std::make_pair(code_point, cursor);
}
32
33static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
34 const std::string & src,
35 llama_partial_utf8 partial_start) {
36 static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
37 const char * pos = src.c_str();
38 std::vector<uint32_t> code_points;
39
40 // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
41 code_points.reserve(src.size() + 1);
42 uint32_t value = partial_start.value;
43 int n_remain = partial_start.n_remain;
44
45 // continue previous decode, if applicable
46 while (*pos != 0 && n_remain > 0) {
47 uint8_t next_byte = static_cast<uint8_t>(*pos);
48 if ((next_byte >> 6) != 2) {
49 // invalid sequence, abort
50 code_points.push_back(0);
51 return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
52 }
53 value = (value << 6) + (next_byte & 0x3F);
54 ++pos;
55 --n_remain;
56 }
57
58 if (partial_start.n_remain > 0 && n_remain == 0) {
59 code_points.push_back(value);
60 }
61
62 // decode any subsequent utf-8 sequences, which may end in an incomplete one
63 while (*pos != 0) {
64 uint8_t first_byte = static_cast<uint8_t>(*pos);
65 uint8_t highbits = first_byte >> 4;
66 n_remain = lookup[highbits] - 1;
67
68 if (n_remain < 0) {
69 // invalid sequence, abort
70 code_points.clear();
71 code_points.push_back(0);
72 return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
73 }
74
75 uint8_t mask = (1 << (7 - n_remain)) - 1;
76 value = first_byte & mask;
77
78 ++pos;
79 while (*pos != 0 && n_remain > 0) {
80 value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
81 ++pos;
82 --n_remain;
83 }
84 if (n_remain == 0) {
85 code_points.push_back(value);
86 }
87 }
88 code_points.push_back(0);
89
90 return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
91}
92
// True when `c` is an ASCII decimal digit ('0' through '9').
static bool is_digit_char(char c) {
    const bool below_zero = c < '0';
    const bool above_nine = c > '9';
    return !(below_zero || above_nine);
}
96
97static bool is_word_char(char c) {
98 return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
99}
100
// Parses exactly `size` hexadecimal digits starting at `src`.
// Returns the parsed value and a pointer just past the digits.
// Throws std::runtime_error when fewer than `size` hex digits are available.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    const char * const stop   = src + size;
    const char *       cursor = src;
    uint32_t           acc    = 0;

    while (cursor < stop && *cursor) {
        const char ch = *cursor;
        uint32_t   nibble;
        if (ch >= '0' && ch <= '9') {
            nibble = ch - '0';
        } else if (ch >= 'a' && ch <= 'f') {
            nibble = ch - 'a' + 10;
        } else if (ch >= 'A' && ch <= 'F') {
            nibble = ch - 'A' + 10;
        } else {
            break; // not a hex digit; reported below because cursor != stop
        }
        acc = (acc << 4) + nibble;
        ++cursor;
    }

    if (cursor != stop) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(acc, cursor);
}
123
// Skips blanks, tabs and '#' comments starting at `src`; line breaks are skipped too when
// `newline_ok` is set. A comment runs up to (but does not include) the next '\r' or '\n'.
// Returns a pointer to the first character that was not skipped.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cursor = src;
    for (;;) {
        switch (*cursor) {
            case ' ':
            case '\t':
                ++cursor;
                continue;
            case '\r':
            case '\n':
                if (!newline_ok) {
                    return cursor;
                }
                ++cursor;
                continue;
            case '#':
                // consume the comment body, leaving the line break for the next iteration
                while (*cursor && *cursor != '\r' && *cursor != '\n') {
                    ++cursor;
                }
                continue;
            default:
                return cursor;
        }
    }
}
138
139static const char * parse_name(const char * src) {
140 const char * pos = src;
141 while (is_word_char(*pos)) {
142 pos++;
143 }
144 if (pos == src) {
145 throw std::runtime_error(std::string("expecting name at ") + src);
146 }
147 return pos;
148}
149
150static const char * parse_int(const char * src) {
151 const char * pos = src;
152 while (is_digit_char(*pos)) {
153 pos++;
154 }
155 if (pos == src) {
156 throw std::runtime_error(std::string("expecting integer at ") + src);
157 }
158 return pos;
159}
160
161static std::pair<uint32_t, const char *> parse_char(const char * src) {
162 if (*src == '\\') {
163 switch (src[1]) {
164 case 'x': return parse_hex(src + 2, 2);
165 case 'u': return parse_hex(src + 2, 4);
166 case 'U': return parse_hex(src + 2, 8);
167 case 't': return std::make_pair('\t', src + 2);
168 case 'r': return std::make_pair('\r', src + 2);
169 case 'n': return std::make_pair('\n', src + 2);
170 case '\\':
171 case '"':
172 case '[':
173 case ']':
174 return std::make_pair(src[1], src + 2);
175 default:
176 throw std::runtime_error(std::string("unknown escape at ") + src);
177 }
178 } else if (*src) {
179 return decode_utf8(src);
180 }
181 throw std::runtime_error("unexpected end of input");
182}
183
184static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
185 const char * pos = src;
186 if (*pos != '<') {
187 throw std::runtime_error(std::string("expecting '<' at ") + pos);
188 }
189 pos++;
190
191 // Parse <[id]>
192 if (*pos == '[') {
193 pos++;
194 const char * int_end = parse_int(pos);
195 uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
196 pos = int_end;
197 if (*pos != ']') {
198 throw std::runtime_error(std::string("expecting ']' at ") + pos);
199 }
200 pos++;
201 if (*pos != '>') {
202 throw std::runtime_error(std::string("expecting '>' at ") + pos);
203 }
204 pos++;
205 return std::make_pair(token_id, pos);
206 }
207
208 if (vocab == nullptr) {
209 throw std::runtime_error(std::string("no vocab to parse token at ") + src);
210 }
211
212 // Parse <token> and tokenize to obtain the token id
213 while (*pos != 0 && *pos != '>') {
214 pos++;
215 }
216 if (*pos != '>') {
217 throw std::runtime_error(std::string("expecting '>' at ") + pos);
218 }
219 pos++;
220
221 llama_token tokens[2];
222 int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
223 if (n_tokens != 1) {
224 // must tokenize to exactly 1 token
225 throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
226 }
227 return std::make_pair(tokens[0], pos);
228}
229
// Writes code point `c` to `file`: printable ASCII (0x20..0x7E) is written as the character
// itself, anything else as <U+XXXX> (hex, at least 4 digits).
static void print_grammar_char(FILE * file, uint32_t c) {
    // 0x7F (DEL) is a control character, so it is escaped rather than emitted raw
    if (0x20 <= c && c < 0x7f) {
        fprintf(file, "%c", static_cast<char>(c));
    } else {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", c);
    }
}
238
239static bool is_char_element(llama_grammar_element elem) {
240 switch (elem.type) {
241 case LLAMA_GRETYPE_CHAR: return true;
242 case LLAMA_GRETYPE_CHAR_NOT: return true;
243 case LLAMA_GRETYPE_CHAR_ALT: return true;
244 case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
245 case LLAMA_GRETYPE_CHAR_ANY: return true;
246 default: return false;
247 }
248}
249
250static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
251 for (auto elem : rule) {
252 switch (elem.type) {
253 case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
254 case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
255 case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
256 case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
257 case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
258 case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
259 case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
260 case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
261 case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
262 case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
263 }
264 switch (elem.type) {
265 case LLAMA_GRETYPE_END:
266 case LLAMA_GRETYPE_ALT:
267 case LLAMA_GRETYPE_RULE_REF:
268 fprintf(file, "(%u) ", elem.value);
269 break;
270 case LLAMA_GRETYPE_CHAR:
271 case LLAMA_GRETYPE_CHAR_NOT:
272 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
273 case LLAMA_GRETYPE_CHAR_ALT:
274 case LLAMA_GRETYPE_CHAR_ANY:
275 fprintf(file, "(\"");
276 print_grammar_char(file, elem.value);
277 fprintf(file, "\") ");
278 break;
279 case LLAMA_GRETYPE_TOKEN:
280 fprintf(file, "<[");
281 fprintf(file, "%u", elem.value);
282 fprintf(file, "]> ");
283 break;
284 case LLAMA_GRETYPE_TOKEN_NOT:
285 fprintf(file, "!");
286 fprintf(file, "<[");
287 fprintf(file, "%u", elem.value);
288 fprintf(file, "]> ");
289 break;
290 }
291 }
292 fprintf(file, "\n");
293}
294
295static void print_rule(
296 FILE * file,
297 uint32_t rule_id,
298 const llama_grammar_rule & rule,
299 const std::map<uint32_t, std::string> & symbol_id_names) {
300 if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
301 throw std::runtime_error(
302 "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
303 }
304 fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
305 for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
306 llama_grammar_element elem = rule[i];
307 switch (elem.type) {
308 case LLAMA_GRETYPE_END:
309 throw std::runtime_error(
310 "unexpected end of rule: " + std::to_string(rule_id) + "," +
311 std::to_string(i));
312 case LLAMA_GRETYPE_ALT:
313 fprintf(file, "| ");
314 break;
315 case LLAMA_GRETYPE_RULE_REF:
316 fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
317 break;
318 case LLAMA_GRETYPE_CHAR:
319 fprintf(file, "[");
320 print_grammar_char(file, elem.value);
321 break;
322 case LLAMA_GRETYPE_CHAR_NOT:
323 fprintf(file, "[^");
324 print_grammar_char(file, elem.value);
325 break;
326 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
327 if (i == 0 || !is_char_element(rule[i - 1])) {
328 throw std::runtime_error(
329 "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
330 std::to_string(rule_id) + "," + std::to_string(i));
331 }
332 fprintf(file, "-");
333 print_grammar_char(file, elem.value);
334 break;
335 case LLAMA_GRETYPE_CHAR_ALT:
336 if (i == 0 || !is_char_element(rule[i - 1])) {
337 throw std::runtime_error(
338 "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
339 std::to_string(rule_id) + "," + std::to_string(i));
340 }
341 print_grammar_char(file, elem.value);
342 break;
343 case LLAMA_GRETYPE_CHAR_ANY:
344 fprintf(file, ".");
345 break;
346 case LLAMA_GRETYPE_TOKEN:
347 fprintf(file, "<[");
348 fprintf(file, "%u", elem.value);
349 fprintf(file, "]> ");
350 break;
351 case LLAMA_GRETYPE_TOKEN_NOT:
352 fprintf(file, "!");
353 fprintf(file, "<[");
354 fprintf(file, "%u", elem.value);
355 fprintf(file, "]> ");
356 break;
357 }
358 if (is_char_element(elem)) {
359 switch (rule[i + 1].type) {
360 case LLAMA_GRETYPE_CHAR_ALT:
361 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
362 case LLAMA_GRETYPE_CHAR_ANY:
363 break;
364 default:
365 fprintf(file, "] ");
366 }
367 }
368 }
369 fprintf(file, "\n");
370}
371
372//
373// Regex utilities
374//
375
376size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
377 auto find_start_pos = [](const std::smatch & match) {
378 // get from the first matched capturing group to the end of the string
379 size_t start = std::string::npos;
380 for (auto i = 1u; i < match.size(); i++) {
381 if (match.length(i) > 0) {
382 start = match.position(i);
383 break;
384 }
385 }
386 if (start == std::string::npos) {
387 start = match.position(0);
388 }
389 return start;
390 };
391
392 if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
393 // match against the entire input
394 std::smatch match;
395 if (std::regex_match(input, match, regex)) {
396 return find_start_pos(match);
397 }
398 }
399
400 // search anywhere
401 std::smatch match;
402 if (std::regex_search(input, match, regex)) {
403 return find_start_pos(match);
404 }
405
406 return std::string::npos;
407}
408
409
410//
411// implementation
412//
413
414uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
415 uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
416 auto result = symbol_ids.emplace(std::string(src, len), next_id);
417 return result.first->second;
418}
419
420uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
421 uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
422 symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
423 return next_id;
424}
425
426void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
427 if (rules.size() <= rule_id) {
428 rules.resize(rule_id + 1);
429 }
430 rules[rule_id] = rule;
431}
432
433const char * llama_grammar_parser::parse_alternates(
434 const char * src,
435 const std::string & rule_name,
436 uint32_t rule_id,
437 bool is_nested) {
438 llama_grammar_rule rule;
439 const char * pos = parse_sequence(src, rule_name, rule, is_nested);
440 while (*pos == '|') {
441 rule.push_back({LLAMA_GRETYPE_ALT, 0});
442 pos = parse_space(pos + 1, true);
443 pos = parse_sequence(pos, rule_name, rule, is_nested);
444 }
445 rule.push_back({LLAMA_GRETYPE_END, 0});
446 add_rule(rule_id, rule);
447 return pos;
448}
449
450const char * llama_grammar_parser::parse_sequence(
451 const char * src,
452 const std::string & rule_name,
453 llama_grammar_rule & rule,
454 bool is_nested) {
455 size_t last_sym_start = rule.size();
456 const char * pos = src;
457
458 // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
459 // (though it's technically the same as -1 now)
460 auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
461 bool no_max = max_times == UINT64_MAX;
462 if (last_sym_start == rule.size()) {
463 throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
464 }
465
466 // apply transformation to previous symbol (last_sym_start to end) according to
467 // the following rewrite rules:
468 // S{m,n} --> S S S (m times) S'(n-m)
469 // S'(x) ::= S S'(x-1) |
470 // (... n-m definitions of these S' rules ...)
471 // S'(1) ::= S |
472 // S{m,} --> S S S (m times) S'
473 // S' ::= S S' |
474 // S* --> S{0,}
475 // --> S' ::= S S' |
476 // S+ --> S{1,}
477 // --> S S'
478 // S' ::= S S' |
479 // S? --> S{0,1}
480 // --> S'
481 // S' ::= S |
482
483 llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
484 if (min_times == 0) {
485 rule.resize(last_sym_start);
486 } else {
487 // Repeat the previous elements (min_times - 1) times
488 for (uint64_t i = 1; i < min_times; i++) {
489 rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
490 }
491 }
492
493 uint32_t last_rec_rule_id = 0;
494 auto n_opt = no_max ? 1 : max_times - min_times;
495
496 llama_grammar_rule rec_rule(prev_rule);
497 for (uint64_t i = 0; i < n_opt; i++) {
498 rec_rule.resize(prev_rule.size());
499 uint32_t rec_rule_id = generate_symbol_id( rule_name);
500 if (i > 0 || no_max) {
501 rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
502 }
503 rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
504 rec_rule.push_back({LLAMA_GRETYPE_END, 0});
505 add_rule( rec_rule_id, rec_rule);
506 last_rec_rule_id = rec_rule_id;
507 }
508 if (n_opt > 0) {
509 rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
510 }
511 };
512
513 while (*pos) {
514 if (*pos == '"') { // literal string
515 pos++;
516 last_sym_start = rule.size();
517 while (*pos != '"') {
518 if (!*pos) {
519 throw std::runtime_error("unexpected end of input");
520 }
521 auto char_pair = parse_char(pos);
522 pos = char_pair.second;
523 rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
524 }
525 pos = parse_space(pos + 1, is_nested);
526 } else if (*pos == '[') { // char range(s)
527 pos++;
528 enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
529 if (*pos == '^') {
530 pos++;
531 start_type = LLAMA_GRETYPE_CHAR_NOT;
532 }
533 last_sym_start = rule.size();
534 while (*pos != ']') {
535 if (!*pos) {
536 throw std::runtime_error("unexpected end of input");
537 }
538 auto char_pair = parse_char(pos);
539 pos = char_pair.second;
540 enum llama_gretype type = last_sym_start < rule.size()
541 ? LLAMA_GRETYPE_CHAR_ALT
542 : start_type;
543
544 rule.push_back({type, char_pair.first});
545 if (pos[0] == '-' && pos[1] != ']') {
546 if (!pos[1]) {
547 throw std::runtime_error("unexpected end of input");
548 }
549 auto endchar_pair = parse_char(pos + 1);
550 pos = endchar_pair.second;
551 rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
552 }
553 }
554 pos = parse_space(pos + 1, is_nested);
555 } else if (*pos == '<' || *pos == '!') { // token
556 auto type = LLAMA_GRETYPE_TOKEN;
557 if (*pos == '!') { // token inverse
558 type = LLAMA_GRETYPE_TOKEN_NOT;
559 pos++;
560 }
561 auto token_pair = parse_token(vocab, pos);
562 const char * token_end = token_pair.second;
563 last_sym_start = rule.size();
564 rule.push_back({type, token_pair.first});
565 pos = parse_space(token_end, is_nested);
566 } else if (is_word_char(*pos)) { // rule reference
567 const char * name_end = parse_name(pos);
568 uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
569 pos = parse_space(name_end, is_nested);
570 last_sym_start = rule.size();
571 rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
572 } else if (*pos == '(') { // grouping
573 // parse nested alternates into synthesized rule
574 pos = parse_space(pos + 1, true);
575 uint32_t sub_rule_id = generate_symbol_id(rule_name);
576 pos = parse_alternates(pos, rule_name, sub_rule_id, true);
577 last_sym_start = rule.size();
578 // output reference to synthesized rule
579 rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
580 if (*pos != ')') {
581 throw std::runtime_error(std::string("expecting ')' at ") + pos);
582 }
583 pos = parse_space(pos + 1, is_nested);
584 } else if (*pos == '.') { // any char
585 last_sym_start = rule.size();
586 rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
587 pos = parse_space(pos + 1, is_nested);
588 } else if (*pos == '*') {
589 pos = parse_space(pos + 1, is_nested);
590 handle_repetitions(0, -1);
591 } else if (*pos == '+') {
592 pos = parse_space(pos + 1, is_nested);
593 handle_repetitions(1, -1);
594 } else if (*pos == '?') {
595 pos = parse_space(pos + 1, is_nested);
596 handle_repetitions(0, 1);
597 } else if (*pos == '{') {
598 pos = parse_space(pos + 1, is_nested);
599
600 if (!is_digit_char(*pos)) {
601 throw std::runtime_error(std::string("expecting an int at ") + pos);
602 }
603 const char * int_end = parse_int(pos);
604 uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
605 pos = parse_space(int_end, is_nested);
606
607 uint64_t max_times = UINT64_MAX; // default: no max limit
608
609 if (*pos == '}') {
610 max_times = min_times;
611 pos = parse_space(pos + 1, is_nested);
612 } else if (*pos == ',') {
613 pos = parse_space(pos + 1, is_nested);
614
615 if (is_digit_char(*pos)) {
616 const char * int_end = parse_int(pos);
617 max_times = std::stoul(std::string(pos, int_end - pos));
618 pos = parse_space(int_end, is_nested);
619 }
620
621 if (*pos != '}') {
622 throw std::runtime_error(std::string("expecting '}' at ") + pos);
623 }
624 pos = parse_space(pos + 1, is_nested);
625 } else {
626 throw std::runtime_error(std::string("expecting ',' at ") + pos);
627 }
628 bool has_max = max_times != UINT64_MAX;
629 if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
630 throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
631 }
632 handle_repetitions(min_times, max_times);
633 } else {
634 break;
635 }
636 }
637 return pos;
638}
639
640const char * llama_grammar_parser::parse_rule(const char * src) {
641 const char * name_end = parse_name(src);
642 const char * pos = parse_space(name_end, false);
643 size_t name_len = name_end - src;
644 uint32_t rule_id = get_symbol_id(src, name_len);
645 const std::string name(src, name_len);
646
647 if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
648 throw std::runtime_error(std::string("expecting ::= at ") + pos);
649 }
650 pos = parse_space(pos + 3, true);
651
652 pos = parse_alternates(pos, name, rule_id, false);
653
654 if (*pos == '\r') {
655 pos += pos[1] == '\n' ? 2 : 1;
656 } else if (*pos == '\n') {
657 pos++;
658 } else if (*pos) {
659 throw std::runtime_error(std::string("expecting newline or end at ") + pos);
660 }
661 return parse_space(pos, true);
662}
663
664bool llama_grammar_parser::parse(const char * src) {
665 try {
666 const char * pos = parse_space(src, true);
667 while (*pos) {
668 pos = parse_rule(pos);
669 }
670 // Validate the state to ensure that all rules are defined
671 for (const auto & rule : rules) {
672 if (rule.empty()) {
673 throw std::runtime_error("Undefined rule");
674 }
675 for (const auto & elem : rule) {
676 if (elem.type == LLAMA_GRETYPE_RULE_REF) {
677 // Ensure that the rule at that location exists
678 if (elem.value >= rules.size() || rules[elem.value].empty()) {
679 // Get the name of the rule that is missing
680 for (const auto & kv : symbol_ids) {
681 if (kv.second == elem.value) {
682 throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
683 }
684 }
685 }
686 }
687 }
688 }
689 } catch (const std::exception & err) {
690 fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
691 rules.clear();
692 return false;
693 }
694
695 return true;
696}
697
698void llama_grammar_parser::print(FILE * file) {
699 try {
700 std::map<uint32_t, std::string> symbol_id_names;
701 for (const auto & kv : symbol_ids) {
702 symbol_id_names[kv.second] = kv.first;
703 }
704 for (size_t i = 0, end = rules.size(); i < end; i++) {
705 // fprintf(file, "%zu: ", i);
706 // print_rule_binary(file, rules[i]);
707 print_rule(file, uint32_t(i), rules[i], symbol_id_names);
708 // fprintf(file, "\n");
709 }
710 } catch (const std::exception & err) {
711 fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
712 }
713}
714
715llama_grammar_stack llama_grammar_parser::c_rules() const {
716 llama_grammar_stack ret;
717 ret.reserve(rules.size());
718 for (const auto & rule : rules) {
719 ret.push_back(rule.data());
720 }
721 return ret;
722}
723
724// returns true iff pos points to the end of one of the definitions of a rule
725static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
726 switch (pos->type) {
727 case LLAMA_GRETYPE_END: return true; // NOLINT
728 case LLAMA_GRETYPE_ALT: return true; // NOLINT
729 default: return false;
730 }
731}
732
733// returns true iff chr satisfies the char range at pos (regular or inverse range)
734// asserts that pos is pointing to a char range element
735static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
736 const llama_grammar_element * pos,
737 const uint32_t chr) {
738 bool found = false;
739 bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
740
741 GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
742
743 do {
744 if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
745 // inclusive range, e.g. [a-z]
746 found = found || (pos->value <= chr && chr <= pos[1].value);
747 pos += 2;
748 } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
749 // Any character matches "."
750 found = true;
751 pos += 1;
752 } else {
753 // exact char match, e.g. [a] or "a"
754 found = found || pos->value == chr;
755 pos += 1;
756 }
757 } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
758
759 return std::make_pair(found == is_positive_char, pos);
760}
761
762// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
763// range at pos (regular or inverse range)
764// asserts that pos is pointing to a char range element
765static bool llama_grammar_match_partial_char(
766 const llama_grammar_element * pos,
767 const llama_partial_utf8 partial_utf8) {
768 bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
769 GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
770
771 uint32_t partial_value = partial_utf8.value;
772 int n_remain = partial_utf8.n_remain;
773
774 // invalid sequence or 7-bit char split across 2 bytes (overlong)
775 if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
776 return false;
777 }
778
779 // range of possible code points this partial UTF-8 sequence could complete to
780 uint32_t low = partial_value << (n_remain * 6);
781 uint32_t high = low | ((1 << (n_remain * 6)) - 1);
782
783 if (low == 0) {
784 if (n_remain == 2) {
785 low = 1 << 11;
786 } else if (n_remain == 3) {
787 low = 1 << 16;
788 }
789 }
790
791 do {
792 if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
793 // inclusive range, e.g. [a-z]
794 if (pos->value <= high && low <= pos[1].value) {
795 return is_positive_char;
796 }
797 pos += 2;
798 } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
799 // Any character matches "."
800 return true;
801 } else {
802 // exact char match, e.g. [a] or "a"
803 if (low <= pos->value && pos->value <= high) {
804 return is_positive_char;
805 }
806 pos += 1;
807 }
808 } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
809
810 return !is_positive_char;
811}
812
813// returns true iff token matches the rule at pos (regular or inverse)
814// asserts that pos is pointing to a token element
815static bool llama_grammar_match_token(
816 const llama_grammar_element * pos,
817 const llama_token token) {
818 GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
819 if (pos->type == LLAMA_GRETYPE_TOKEN) {
820 return pos->value == static_cast<uint32_t>(token);
821 }
822 if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
823 return pos->value != static_cast<uint32_t>(token);
824 }
825 return false;
826}
827
828// transforms a grammar pushdown stack into N possible stacks, all ending
829// at a character range (terminal element)
830static void llama_grammar_advance_stack(
831 const llama_grammar_rules & rules,
832 const llama_grammar_stack & stack,
833 llama_grammar_stacks & new_stacks) {
834 if (stack.empty()) {
835 if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
836 new_stacks.emplace_back(stack);
837 }
838 return;
839 }
840
841 const llama_grammar_element * pos = stack.back();
842
843 switch (pos->type) {
844 case LLAMA_GRETYPE_RULE_REF: {
845 const size_t rule_id = static_cast<size_t>(pos->value);
846 const llama_grammar_element * subpos = rules[rule_id].data();
847 do {
848 // init new stack without the top (pos)
849 llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
850 if (!llama_grammar_is_end_of_sequence(pos + 1)) {
851 // if this rule ref is followed by another element, add that to stack
852 new_stack.push_back(pos + 1);
853 }
854 if (!llama_grammar_is_end_of_sequence(subpos)) {
855 // if alternate is nonempty, add to stack
856 new_stack.push_back(subpos);
857 }
858 llama_grammar_advance_stack(rules, new_stack, new_stacks);
859 while (!llama_grammar_is_end_of_sequence(subpos)) {
860 // scan to end of alternate def
861 subpos++;
862 }
863 if (subpos->type == LLAMA_GRETYPE_ALT) {
864 // there's another alternate def of this rule to process
865 subpos++;
866 } else {
867 break;
868 }
869 } while (true);
870 break;
871 }
872 case LLAMA_GRETYPE_CHAR:
873 case LLAMA_GRETYPE_CHAR_NOT:
874 case LLAMA_GRETYPE_CHAR_ANY:
875 case LLAMA_GRETYPE_TOKEN:
876 case LLAMA_GRETYPE_TOKEN_NOT:
877 if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
878 // only add the stack if it's not a duplicate of one we already have
879 new_stacks.emplace_back(stack);
880 }
881 break;
882 default:
883 // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
884 // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
885 // those
886 GGML_ABORT("fatal error");
887 }
888}
889
890static llama_grammar_candidates llama_grammar_reject_candidates(
891 const llama_grammar_rules & rules,
892 const llama_grammar_stacks & stacks,
893 const llama_grammar_candidates & candidates) {
894 GGML_ASSERT(!stacks.empty()); // REVIEW
895
896 if (candidates.empty()) {
897 return {};
898 }
899
900 auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
901
902 for (size_t i = 1, size = stacks.size(); i < size; ++i) {
903 rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
904 }
905
906 return rejects;
907}
908
909static bool llama_grammar_detect_left_recursion(
910 const llama_grammar_rules & rules,
911 size_t rule_index,
912 std::vector<bool> * rules_visited,
913 std::vector<bool> * rules_in_progress,
914 std::vector<bool> * rules_may_be_empty) {
915 if ((*rules_in_progress)[rule_index]) {
916 return true;
917 }
918
919 (*rules_in_progress)[rule_index] = true;
920
921 const llama_grammar_rule & rule = rules[rule_index];
922
923 // First check if the rule might produce the empty string. This could be done combined with the second
924 // step but it's more readable as two steps.
925 bool at_rule_start = true;
926 for (size_t i = 0; i < rule.size(); i++) {
927 if (llama_grammar_is_end_of_sequence(&rule[i])) {
928 if (at_rule_start) {
929 (*rules_may_be_empty)[rule_index] = true;
930 break;
931 }
932 at_rule_start = true;
933 } else {
934 at_rule_start = false;
935 }
936 }
937
938 // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
939 // be empty)
940 bool recurse_into_nonterminal = true;
941 for (size_t i = 0; i < rule.size(); i++) {
942 if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
943 if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
944 return true;
945 }
946 if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
947 recurse_into_nonterminal = false;
948 }
949 } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
950 recurse_into_nonterminal = true;
951 } else {
952 recurse_into_nonterminal = false;
953 }
954 }
955
956 (*rules_in_progress)[rule_index] = false;
957 (*rules_visited)[rule_index] = true;
958
959 return false;
960}
961
962const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
963 return grammar->rules;
964}
965
966llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
967 return grammar->stacks;
968}
969
970static void llama_grammar_accept_chr(
971 struct llama_grammar & grammar,
972 const llama_grammar_stack & stack,
973 uint32_t chr,
974 llama_grammar_stacks & new_stacks) {
975 if (stack.empty()) {
976 return;
977 }
978
979 const llama_grammar_element * pos = stack.back();
980
981 // ignore if this turns into a token
982 if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
983 return;
984 }
985
986 auto match = llama_grammar_match_char(pos, chr);
987 if (match.first) {
988 llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
989 if (!llama_grammar_is_end_of_sequence(match.second)) {
990 new_stack.push_back(match.second);
991 }
992 llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
993 }
994}
995
996void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
997 llama_grammar_stacks stacks_new;
998 stacks_new.reserve(grammar->stacks.size());
999
1000 for (const auto & stack : grammar->stacks) {
1001 llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
1002 }
1003
1004 grammar->stacks = std::move(stacks_new);
1005}
1006
1007llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
1008 const llama_grammar_rules & rules,
1009 const llama_grammar_stack & stack,
1010 const llama_grammar_candidates & candidates) {
1011
1012 llama_grammar_candidates rejects;
1013 rejects.reserve(candidates.size());
1014
1015 if (stack.empty()) {
1016 for (const auto & tok : candidates) {
1017 if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
1018 rejects.push_back(tok);
1019 }
1020 }
1021 return rejects;
1022 }
1023
1024 const llama_grammar_element * stack_pos = stack.back();
1025
1026 // if the top of the stack is a token rule, then we only need to check the token id
1027 if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1028 for (const auto & tok : candidates) {
1029 if (*tok.code_points == 0) {
1030 // reached the end of a token consumed by char rules, reject iff it ended
1031 // in a partial response
1032 if (tok.partial_utf8.n_remain != 0) {
1033 rejects.push_back(tok);
1034 }
1035 } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
1036 rejects.push_back(tok);
1037 }
1038 }
1039 return rejects;
1040 }
1041
1042 llama_grammar_candidates next_candidates;
1043 next_candidates.reserve(candidates.size());
1044
1045 for (const auto & tok : candidates) {
1046 if (*tok.code_points == 0) {
1047 // reached end of full codepoints in token, reject iff it ended in a partial sequence
1048 // that cannot satisfy this position in grammar
1049 if (tok.partial_utf8.n_remain != 0 &&
1050 !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
1051 rejects.push_back(tok);
1052 }
1053 } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
1054 next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
1055 } else {
1056 rejects.push_back(tok);
1057 }
1058 }
1059
1060 const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
1061
1062 // update top of stack to next element, if any
1063 llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
1064 if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
1065 stack_after.push_back(stack_pos_after);
1066 }
1067 llama_grammar_stacks next_stacks;
1068 llama_grammar_advance_stack(rules, stack_after, next_stacks);
1069
1070 auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
1071 for (const auto & tok : next_rejects) {
1072 rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
1073 }
1074
1075 return rejects;
1076}
1077
1078////////////////////
1079
1080struct llama_grammar * llama_grammar_init_impl(
1081 const struct llama_vocab * vocab,
1082 const llama_grammar_element ** rules,
1083 size_t n_rules,
1084 size_t start_rule_index) {
1085 const llama_grammar_element * pos;
1086
1087 // copy rule definitions into vectors
1088 llama_grammar_rules vec_rules(n_rules);
1089 for (size_t i = 0; i < n_rules; i++) {
1090 for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1091 vec_rules[i].push_back(*pos);
1092 }
1093 vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1094 }
1095
1096 // Check for left recursion
1097 std::vector<bool> rules_visited(n_rules);
1098 std::vector<bool> rules_in_progress(n_rules);
1099 std::vector<bool> rules_may_be_empty(n_rules);
1100 for (size_t i = 0; i < n_rules; i++) {
1101 if (rules_visited[i]) {
1102 continue;
1103 }
1104 if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1105 LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1106 return nullptr;
1107 }
1108 }
1109
1110 // loop over alternates of start rule to build initial stacks
1111 llama_grammar_stacks stacks;
1112 pos = vec_rules[start_rule_index].data();
1113 do {
1114 llama_grammar_stack stack;
1115 if (!llama_grammar_is_end_of_sequence(pos)) {
1116 // if alternate is nonempty, add to stack
1117 stack.push_back(pos);
1118 }
1119 llama_grammar_advance_stack(vec_rules, stack, stacks);
1120 while (!llama_grammar_is_end_of_sequence(pos)) {
1121 // scan to end of alternate def
1122 pos++;
1123 }
1124 if (pos->type == LLAMA_GRETYPE_ALT) {
1125 // there's another alternate def of this rule to process
1126 pos++;
1127 } else {
1128 break;
1129 }
1130 } while (true);
1131
1132 // Important: vec_rules has to be moved here, not copied, because stacks contains
1133 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1134 // then the pointers would be invalidated when the local vec_rules goes out of scope.
1135 return new llama_grammar {
1136 vocab,
1137 std::move(vec_rules),
1138 std::move(stacks),
1139 /* .partial_utf8 = */ {},
1140 /* .lazy = */ false,
1141 /* .awaiting_trigger = */ false,
1142 /* .trigger_buffer = */ "",
1143 /* .trigger_buffer_positions = */ {},
1144 /* .trigger_tokens = */ {},
1145 /* .trigger_patterns = */ {},
1146 };
1147}
1148
1149struct llama_grammar * llama_grammar_init_impl(
1150 const struct llama_vocab * vocab,
1151 const char * grammar_str,
1152 const char * grammar_root,
1153 bool lazy,
1154 const char ** trigger_patterns,
1155 size_t num_trigger_patterns,
1156 const llama_token * trigger_tokens,
1157 size_t num_trigger_tokens) {
1158 llama_grammar_parser parser(vocab);
1159
1160 // if there is a grammar, parse it
1161 // rules will be empty (default) if there are parse errors
1162 if (!parser.parse(grammar_str) || parser.rules.empty()) {
1163 fprintf(stderr, "%s: failed to parse grammar\n", __func__);
1164 return nullptr;
1165 }
1166
1167 // Ensure that there is a "root" node.
1168 if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
1169 fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
1170 return nullptr;
1171 }
1172
1173 std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
1174
1175 const size_t n_rules = grammar_rules.size();
1176 const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
1177
1178 const llama_grammar_element * pos;
1179
1180 // copy rule definitions into vectors
1181 llama_grammar_rules vec_rules(n_rules);
1182 for (size_t i = 0; i < n_rules; i++) {
1183 for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
1184 vec_rules[i].push_back(*pos);
1185 }
1186 vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1187 }
1188
1189 // Check for left recursion
1190 std::vector<bool> rules_visited(n_rules);
1191 std::vector<bool> rules_in_progress(n_rules);
1192 std::vector<bool> rules_may_be_empty(n_rules);
1193 for (size_t i = 0; i < n_rules; i++) {
1194 if (rules_visited[i]) {
1195 continue;
1196 }
1197 if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
1198 LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
1199 return nullptr;
1200 }
1201 }
1202
1203 // loop over alternates of start rule to build initial stacks
1204 llama_grammar_stacks stacks;
1205 pos = vec_rules[start_rule_index].data();
1206 do {
1207 llama_grammar_stack stack;
1208 if (!llama_grammar_is_end_of_sequence(pos)) {
1209 // if alternate is nonempty, add to stack
1210 stack.push_back(pos);
1211 }
1212 llama_grammar_advance_stack(vec_rules, stack, stacks);
1213 while (!llama_grammar_is_end_of_sequence(pos)) {
1214 // scan to end of alternate def
1215 pos++;
1216 }
1217 if (pos->type == LLAMA_GRETYPE_ALT) {
1218 // there's another alternate def of this rule to process
1219 pos++;
1220 } else {
1221 break;
1222 }
1223 } while (true);
1224
1225 std::vector<llama_token> vec_trigger_tokens;
1226 std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
1227 for (size_t i = 0; i < num_trigger_tokens; i++) {
1228 GGML_ASSERT(trigger_tokens != nullptr);
1229 vec_trigger_tokens.push_back(trigger_tokens[i]);
1230 }
1231 for (size_t i = 0; i < num_trigger_patterns; i++) {
1232 GGML_ASSERT(trigger_patterns != nullptr);
1233 auto & trigger = vec_trigger_patterns.emplace_back();
1234 trigger.pattern = trigger_patterns[i];
1235 trigger.regex = std::regex(trigger.pattern);
1236 }
1237
1238 // Important: vec_rules has to be moved here, not copied, because stacks contains
1239 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1240 // then the pointers would be invalidated when the local vec_rules goes out of scope.
1241 return new llama_grammar {
1242 vocab,
1243 std::move(vec_rules),
1244 std::move(stacks),
1245 /* .partial_utf8 = */ {},
1246 /* .lazy = */ lazy,
1247 /* .awaiting_trigger = */ lazy,
1248 /* .trigger_buffer = */ "",
1249 /* .trigger_buffer_positions = */ {},
1250 std::move(vec_trigger_tokens),
1251 std::move(vec_trigger_patterns),
1252 };
1253}
1254
// Destroys a grammar previously allocated with `new`. Passing nullptr is allowed.
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    // deleting a null pointer is a no-op, so no explicit guard is needed
    delete grammar;
}
1262
1263struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1264 auto * result = new llama_grammar {
1265 grammar.vocab,
1266 grammar.rules,
1267 grammar.stacks,
1268 grammar.partial_utf8,
1269 grammar.lazy,
1270 grammar.awaiting_trigger,
1271 grammar.trigger_buffer,
1272 grammar.trigger_buffer_positions,
1273 grammar.trigger_tokens,
1274 grammar.trigger_patterns,
1275 };
1276
1277 // redirect elements in stacks to point to new rules
1278 for (size_t is = 0; is < result->stacks.size(); is++) {
1279 for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
1280 for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
1281 for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
1282 if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1283 result->stacks[is][ie] = &result->rules[ir0][ir1];
1284 }
1285 }
1286 }
1287 }
1288 }
1289
1290 return result;
1291}
1292
1293void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1294 GGML_ASSERT(grammar.vocab != nullptr);
1295
1296 if (grammar.awaiting_trigger) {
1297 return;
1298 }
1299
1300 bool allow_eog = false;
1301 for (const auto & stack : grammar.stacks) {
1302 if (stack.empty()) {
1303 allow_eog = true;
1304 break;
1305 }
1306 }
1307
1308 std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
1309 candidates_decoded.reserve(cur_p->size);
1310
1311 llama_grammar_candidates candidates_grammar;
1312 candidates_grammar.reserve(cur_p->size);
1313
1314 for (size_t i = 0; i < cur_p->size; ++i) {
1315 const llama_token id = cur_p->data[i].id;
1316 const std::string & piece = grammar.vocab->token_to_piece(id);
1317
1318 if (grammar.vocab->is_eog(id)) {
1319 if (!allow_eog) {
1320 cur_p->data[i].logit = -INFINITY;
1321 }
1322 } else if (piece.empty() || piece[0] == 0) {
1323 cur_p->data[i].logit = -INFINITY;
1324 } else {
1325 candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1326 candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
1327 }
1328 }
1329
1330 const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
1331 for (const auto & reject : rejects) {
1332 cur_p->data[reject.index].logit = -INFINITY;
1333 }
1334}
1335
1336void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
1337 GGML_ASSERT(grammar.vocab != nullptr);
1338
1339 const auto & piece = grammar.vocab->token_to_piece(token);
1340
1341 if (grammar.awaiting_trigger) {
1342 if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
1343 grammar.awaiting_trigger = false;
1344 grammar.trigger_buffer.clear();
1345 llama_grammar_accept_token(grammar, token, piece);
1346 LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1347 return;
1348 } else {
1349 auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
1350 grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
1351 grammar.trigger_buffer += piece;
1352
1353 for (const auto & trigger_pattern : grammar.trigger_patterns) {
1354 auto start = trigger_pattern.find(grammar.trigger_buffer);
1355 if (start != std::string::npos) {
1356 grammar.awaiting_trigger = false;
1357
1358 // replay tokens that overlap with [start, end)
1359 for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
1360 auto [tok_start, tok_end] = tok_pos;
1361 if (tok_end <= start) {
1362 continue;
1363 }
1364
1365 size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
1366 size_t piece_len = tok_end - piece_start;
1367 auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
1368 llama_grammar_accept_token(grammar, tok, tok_piece);
1369 }
1370
1371 auto constrained_str = grammar.trigger_buffer.substr(start);
1372 grammar.trigger_buffer.clear();
1373 grammar.trigger_buffer_positions.clear();
1374 LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1375 return;
1376 }
1377 }
1378 LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
1379 return;
1380 }
1381 }
1382
1383 if (grammar.vocab->is_eog(token)) {
1384 for (const auto & stack : grammar.stacks) {
1385 if (stack.empty()) {
1386 return;
1387 }
1388 }
1389 GGML_ABORT("fatal error");
1390 }
1391
1392 llama_grammar_accept_token(grammar, token, piece);
1393}
1394
1395void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
1396 // Note terminating 0 in decoded string
1397 const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1398 const auto & code_points = decoded.first;
1399
1400 for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1401 llama_grammar_accept(&grammar, *it);
1402 }
1403
1404 grammar.partial_utf8 = decoded.second;
1405 if (grammar.stacks.empty()) {
1406 throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
1407 }
1408}
1409
1410void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
1411 // Note terminating 0 in decoded string
1412 const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1413 const auto & code_points = decoded.first;
1414
1415 llama_grammar_stacks stacks_new;
1416 stacks_new.reserve(grammar.stacks.size());
1417
1418 for (const auto & stack : grammar.stacks) {
1419 if (stack.empty()) {
1420 continue;
1421 }
1422
1423 const llama_grammar_element * pos = stack.back();
1424
1425 if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1426 if (llama_grammar_match_token(pos, token)) {
1427 llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1428 if (!llama_grammar_is_end_of_sequence(pos + 1)) {
1429 new_stack.push_back(pos + 1);
1430 }
1431 llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
1432 }
1433 } else {
1434 llama_grammar_stacks current_stacks = {stack};
1435
1436 for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1437 llama_grammar_stacks next_stacks;
1438
1439 for (const auto & cur_stack : current_stacks) {
1440 llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
1441 }
1442
1443 current_stacks = std::move(next_stacks);
1444 if (current_stacks.empty()) {
1445 break;
1446 }
1447 }
1448
1449 for (auto & surviving_stack : current_stacks) {
1450 if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
1451 stacks_new.emplace_back(surviving_stack);
1452 }
1453 }
1454 }
1455 }
1456
1457 grammar.stacks = std::move(stacks_new);
1458 grammar.partial_utf8 = decoded.second;
1459
1460 if (grammar.stacks.empty()) {
1461 throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
1462 }
1463}
1464