1#ifdef NDEBUG
2#undef NDEBUG
3#endif
4
5#include "llama.h"
6
// TODO: should not include libllama sources
8#include "../src/llama-grammar.h"
9
10#include <cassert>
11
12static const char * type_str(llama_gretype type) {
13 switch (type) {
14 case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR";
15 case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT";
16 case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT";
17 case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER";
18 case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF";
19 case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT";
20 case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END";
21 default: return "?";
22 }
23}
24
25static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
26 uint32_t index = 0;
27 llama_grammar_parser parsed_grammar;
28 parsed_grammar.parse(grammar_bytes);
29
30 std::map<uint32_t, std::string> symbol_names;
31 for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
32 symbol_names[it->second] = it->first;
33 }
34
35 auto print_all = [&]() {
36 fprintf(stderr, " verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes);
37 for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
38 fprintf(stderr, " {\"%s\", %u},\n", it->first.c_str(), it->second);
39 }
40 fprintf(stderr, " }, {\n");
41 for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) {
42 fprintf(stderr, " // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule);
43 auto & rule = parsed_grammar.rules[i_rule];
44 for (uint32_t i = 0; i < rule.size(); i++) {
45 std::string rule_str;
46 fprintf(stderr, " {%s, ", type_str(rule[i].type));
47 if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT ||
48 rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
49 char c = rule[i].value;
50 if (c == '\n') {
51 fprintf(stderr, "'\\n'");
52 } else if (c == '\t') {
53 fprintf(stderr, "'\\t'");
54 } else if (c == '\r') {
55 fprintf(stderr, "'\\r'");
56 } else if (c == '\0') {
57 fprintf(stderr, "'\\0'");
58 } else {
59 fprintf(stderr, "'%c'", c);
60 }
61 } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
62 fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value);
63 } else {
64 fprintf(stderr, "%u", rule[i].value);
65 }
66 fprintf(stderr, "},\n");
67 }
68 }
69 fprintf(stderr, " });\n");
70 };
71
72 if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) {
73 print_all();
74 fprintf(stderr, "\n");
75 return;
76 }
77
78 fprintf(stderr, "Testing grammar:%s\n", grammar_bytes);
79
80 if (parsed_grammar.symbol_ids.size() != expected.size()) {
81 fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
82 print_all();
83 assert(parsed_grammar.symbol_ids.size() == expected.size());
84 }
85
86 for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
87 {
88 std::string key = it->first;
89 uint32_t value = it->second;
90 std::pair<std::string, uint32_t> expected_pair = expected[index];
91
92 // pretty print error message before asserting
93 if (expected_pair.first != key || expected_pair.second != value)
94 {
95 fprintf(stderr, "index: %u\n", index);
96 fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
97 fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
98 fprintf(stderr, "expected_pair != actual_pair\n");
99 fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
100 print_all();
101 }
102
103 assert(expected_pair.first == key && expected_pair.second == value);
104
105 index++;
106 }
107
108 index = 0;
109 for (auto rule : parsed_grammar.rules)
110 {
111 // compare rule to expected rule
112 for (uint32_t i = 0; i < rule.size(); i++)
113 {
114 llama_grammar_element element = rule[i];
115 llama_grammar_element expected_element = expected_rules[index];
116
117 // pretty print error message before asserting
118 if (expected_element.type != element.type || expected_element.value != element.value)
119 {
120 fprintf(stderr, "index: %u\n", index);
121 fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value);
122 fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value);
123 fprintf(stderr, "expected_element != actual_element\n");
124 fprintf(stderr, "all elements:\n");
125 fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
126 print_all();
127 }
128
129 assert(expected_element.type == element.type && expected_element.value == element.value);
130 index++;
131 }
132 }
133}
134
135static void verify_failure(const char * grammar_bytes) {
136 fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
137 llama_grammar_parser result;
138 result.parse(grammar_bytes);
139 assert(result.rules.empty() && "should have failed");
140}
141
142int main()
143{
144 verify_failure(R"""(
145 root ::= "a"{,}"
146 )""");
147
148 verify_failure(R"""(
149 root ::= "a"{,10}"
150 )""");
151
152 verify_parsing(R"""(
153 root ::= "a"
154 )""", {
155 {"root", 0},
156 }, {
157 // root (index 0)
158 {LLAMA_GRETYPE_CHAR, 'a'},
159 {LLAMA_GRETYPE_END, 0},
160 });
161
162 verify_parsing(R"""(
163 root ::= "a" | [bdx-z] | [^1-3]
164 )""", {
165 {"root", 0},
166 }, {
167 // root (index 0)
168 {LLAMA_GRETYPE_CHAR, 'a'},
169 {LLAMA_GRETYPE_ALT, 0},
170 {LLAMA_GRETYPE_CHAR, 'b'},
171 {LLAMA_GRETYPE_CHAR_ALT, 'd'},
172 {LLAMA_GRETYPE_CHAR_ALT, 'x'},
173 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
174 {LLAMA_GRETYPE_ALT, 0},
175 {LLAMA_GRETYPE_CHAR_NOT, '1'},
176 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'},
177 {LLAMA_GRETYPE_END, 0},
178 });
179
180 verify_parsing(R"""(
181 root ::= a+
182 a ::= "a"
183 )""", {
184 {"a", 1},
185 {"root", 0},
186 {"root_2", 2},
187 }, {
188 // root (index 0)
189 {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
190 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
191 {LLAMA_GRETYPE_END, 0},
192 // a (index 1)
193 {LLAMA_GRETYPE_CHAR, 'a'},
194 {LLAMA_GRETYPE_END, 0},
195 // root_2 (index 2)
196 {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
197 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
198 {LLAMA_GRETYPE_ALT, 0},
199 {LLAMA_GRETYPE_END, 0},
200 });
201
202 verify_parsing(R"""(
203 root ::= "a"+
204 )""", {
205 {"root", 0},
206 {"root_1", 1},
207 }, {
208 // root (index 0)
209 {LLAMA_GRETYPE_CHAR, 'a'},
210 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
211 {LLAMA_GRETYPE_END, 0},
212 // root_1 (index 1)
213 {LLAMA_GRETYPE_CHAR, 'a'},
214 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
215 {LLAMA_GRETYPE_ALT, 0},
216 {LLAMA_GRETYPE_END, 0},
217 });
218
219 verify_parsing(R"""(
220 root ::= a?
221 a ::= "a"
222 )""", {
223 {"a", 1},
224 {"root", 0},
225 {"root_2", 2},
226 }, {
227 // root (index 0)
228 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
229 {LLAMA_GRETYPE_END, 0},
230 // a (index 1)
231 {LLAMA_GRETYPE_CHAR, 'a'},
232 {LLAMA_GRETYPE_END, 0},
233 // root_2 (index 2)
234 {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
235 {LLAMA_GRETYPE_ALT, 0},
236 {LLAMA_GRETYPE_END, 0},
237 });
238
239 verify_parsing(R"""(
240 root ::= "a"?
241 )""", {
242 {"root", 0},
243 {"root_1", 1},
244 }, {
245 // root (index 0)
246 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
247 {LLAMA_GRETYPE_END, 0},
248 // root_1 (index 1)
249 {LLAMA_GRETYPE_CHAR, 'a'},
250 {LLAMA_GRETYPE_ALT, 0},
251 {LLAMA_GRETYPE_END, 0},
252 });
253
254 verify_parsing(R"""(
255 root ::= a*
256 a ::= "a"
257 )""", {
258 {"a", 1},
259 {"root", 0},
260 {"root_2", 2},
261 }, {
262 // root (index 0)
263 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
264 {LLAMA_GRETYPE_END, 0},
265 // a (index 1)
266 {LLAMA_GRETYPE_CHAR, 'a'},
267 {LLAMA_GRETYPE_END, 0},
268 // root_2 (index 2)
269 {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
270 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
271 {LLAMA_GRETYPE_ALT, 0},
272 {LLAMA_GRETYPE_END, 0},
273 });
274
275 verify_parsing(R"""(
276 root ::= "a"*
277 )""", {
278 {"root", 0},
279 {"root_1", 1},
280 }, {
281 // root (index 0)
282 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
283 {LLAMA_GRETYPE_END, 0},
284 // root_1 (index 1)
285 {LLAMA_GRETYPE_CHAR, 'a'},
286 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
287 {LLAMA_GRETYPE_ALT, 0},
288 {LLAMA_GRETYPE_END, 0},
289 });
290
291 verify_parsing(R"""(
292 root ::= "a"{2}
293 )""", {
294 {"root", 0},
295 }, {
296 // root (index 0)
297 {LLAMA_GRETYPE_CHAR, 'a'},
298 {LLAMA_GRETYPE_CHAR, 'a'},
299 {LLAMA_GRETYPE_END, 0},
300 });
301
302 verify_parsing(R"""(
303 root ::= "a"{2,}
304 )""", {
305 {"root", 0},
306 {"root_1", 1},
307 }, {
308 // root (index 0)
309 {LLAMA_GRETYPE_CHAR, 'a'},
310 {LLAMA_GRETYPE_CHAR, 'a'},
311 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
312 {LLAMA_GRETYPE_END, 0},
313 // root_1 (index 1)
314 {LLAMA_GRETYPE_CHAR, 'a'},
315 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
316 {LLAMA_GRETYPE_ALT, 0},
317 {LLAMA_GRETYPE_END, 0},
318 });
319
320 verify_parsing(R"""(
321 root ::= "a"{ 4}
322 )""", {
323 {"root", 0},
324 }, {
325 // root (index 0)
326 {LLAMA_GRETYPE_CHAR, 'a'},
327 {LLAMA_GRETYPE_CHAR, 'a'},
328 {LLAMA_GRETYPE_CHAR, 'a'},
329 {LLAMA_GRETYPE_CHAR, 'a'},
330 {LLAMA_GRETYPE_END, 0},
331 });
332
333 verify_parsing(R"""(
334 root ::= "a"{2,4}
335 )""", {
336 {"root", 0},
337 {"root_1", 1},
338 {"root_2", 2},
339 }, {
340 // root (index 0)
341 {LLAMA_GRETYPE_CHAR, 'a'},
342 {LLAMA_GRETYPE_CHAR, 'a'},
343 {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
344 {LLAMA_GRETYPE_END, 0},
345 // root_1 (index 1)
346 {LLAMA_GRETYPE_CHAR, 'a'},
347 {LLAMA_GRETYPE_ALT, 0},
348 {LLAMA_GRETYPE_END, 0},
349 // root_2 (index 2)
350 {LLAMA_GRETYPE_CHAR, 'a'},
351 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
352 {LLAMA_GRETYPE_ALT, 0},
353 {LLAMA_GRETYPE_END, 0},
354 });
355
356 verify_parsing(R"""(
357 root ::= (expr "=" term "\n")+
358 expr ::= term ([-+*/] term)*
359 term ::= [0-9]+
360 )""", {
361 {"expr", 2},
362 {"expr_5", 5},
363 {"expr_6", 6},
364 {"root", 0},
365 {"root_1", 1},
366 {"root_4", 4},
367 {"term", 3},
368 {"term_7", 7},
369 }, {
370 // root (index 0)
371 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
372 {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
373 {LLAMA_GRETYPE_END, 0},
374 // root_1 (index 1)
375 {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
376 {LLAMA_GRETYPE_CHAR, '='},
377 {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
378 {LLAMA_GRETYPE_CHAR, '\n'},
379 {LLAMA_GRETYPE_END, 0},
380 // expr (index 2)
381 {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
382 {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
383 {LLAMA_GRETYPE_END, 0},
384 // term (index 3)
385 {LLAMA_GRETYPE_CHAR, '0'},
386 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
387 {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
388 {LLAMA_GRETYPE_END, 0},
389 // root_4 (index 4)
390 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
391 {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
392 {LLAMA_GRETYPE_ALT, 0},
393 {LLAMA_GRETYPE_END, 0},
394 // expr_5 (index 5)
395 {LLAMA_GRETYPE_CHAR, '-'},
396 {LLAMA_GRETYPE_CHAR_ALT, '+'},
397 {LLAMA_GRETYPE_CHAR_ALT, '*'},
398 {LLAMA_GRETYPE_CHAR_ALT, '/'},
399 {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
400 {LLAMA_GRETYPE_END, 0},
401 // expr_6 (index 6)
402 {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5},
403 {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
404 {LLAMA_GRETYPE_ALT, 0},
405 {LLAMA_GRETYPE_END, 0},
406 // term_7 (index 7)
407 {LLAMA_GRETYPE_CHAR, '0'},
408 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
409 {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
410 {LLAMA_GRETYPE_ALT, 0},
411 {LLAMA_GRETYPE_END, 0},
412 });
413
414 verify_parsing(R"""(
415 root ::= (expr "=" ws term "\n")+
416 expr ::= term ([-+*/] term)*
417 term ::= ident | num | "(" ws expr ")" ws
418 ident ::= [a-z] [a-z0-9_]* ws
419 num ::= [0-9]+ ws
420 ws ::= [ \t\n]*
421 )""", {
422 {"expr", 2},
423 {"expr_6", 6},
424 {"expr_7", 7},
425 {"ident", 8},
426 {"ident_10", 10},
427 {"num", 9},
428 {"num_11", 11},
429 {"root", 0},
430 {"root_1", 1},
431 {"root_5", 5},
432 {"term", 4},
433 {"ws", 3},
434 {"ws_12", 12},
435 }, {
436 // root (index 0)
437 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
438 {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
439 {LLAMA_GRETYPE_END, 0},
440 // root_1 (index 1)
441 {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
442 {LLAMA_GRETYPE_CHAR, '='},
443 {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
444 {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
445 {LLAMA_GRETYPE_CHAR, '\n'},
446 {LLAMA_GRETYPE_END, 0},
447 // expr (index 2)
448 {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
449 {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
450 {LLAMA_GRETYPE_END, 0},
451 // ws (index 3)
452 {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
453 {LLAMA_GRETYPE_END, 0},
454 // term (index 4)
455 {LLAMA_GRETYPE_RULE_REF, /* ident */ 8},
456 {LLAMA_GRETYPE_ALT, 0},
457 {LLAMA_GRETYPE_RULE_REF, /* num */ 9},
458 {LLAMA_GRETYPE_ALT, 0},
459 {LLAMA_GRETYPE_CHAR, '('},
460 {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
461 {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
462 {LLAMA_GRETYPE_CHAR, ')'},
463 {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
464 {LLAMA_GRETYPE_END, 0},
465 // root_5 (index 5)
466 {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
467 {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
468 {LLAMA_GRETYPE_ALT, 0},
469 {LLAMA_GRETYPE_END, 0},
470 // expr_6 (index 6)
471 {LLAMA_GRETYPE_CHAR, '-'},
472 {LLAMA_GRETYPE_CHAR_ALT, '+'},
473 {LLAMA_GRETYPE_CHAR_ALT, '*'},
474 {LLAMA_GRETYPE_CHAR_ALT, '/'},
475 {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
476 {LLAMA_GRETYPE_END, 0},
477 // expr_7 (index 7)
478 {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
479 {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
480 {LLAMA_GRETYPE_ALT, 0},
481 {LLAMA_GRETYPE_END, 0},
482 // ident (index 8)
483 {LLAMA_GRETYPE_CHAR, 'a'},
484 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
485 {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
486 {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
487 {LLAMA_GRETYPE_END, 0},
488 // num (index 9)
489 {LLAMA_GRETYPE_CHAR, '0'},
490 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
491 {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
492 {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
493 {LLAMA_GRETYPE_END, 0},
494 // ident_10 (index 10)
495 {LLAMA_GRETYPE_CHAR, 'a'},
496 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
497 {LLAMA_GRETYPE_CHAR_ALT, '0'},
498 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
499 {LLAMA_GRETYPE_CHAR_ALT, '_'},
500 {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
501 {LLAMA_GRETYPE_ALT, 0},
502 {LLAMA_GRETYPE_END, 0},
503 // num_11 (index 11)
504 {LLAMA_GRETYPE_CHAR, '0'},
505 {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
506 {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
507 {LLAMA_GRETYPE_ALT, 0},
508 {LLAMA_GRETYPE_END, 0},
509 // ws_12 (index 12)
510 {LLAMA_GRETYPE_CHAR, ' '},
511 {LLAMA_GRETYPE_CHAR_ALT, '\t'},
512 {LLAMA_GRETYPE_CHAR_ALT, '\n'},
513 {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
514 {LLAMA_GRETYPE_ALT, 0},
515 {LLAMA_GRETYPE_END, 0},
516 });
517
518 // <[1000]> = "<think>"
519 // <[1001]> = "</think>"
520 verify_parsing(R"""(
521 root ::= <[1000]> !<[1001]> <[1001]>
522 )""", {
523 {"root", 0}
524 }, {
525 // root (index 0)
526 {LLAMA_GRETYPE_TOKEN, 1000},
527 {LLAMA_GRETYPE_TOKEN_NOT, 1001},
528 {LLAMA_GRETYPE_TOKEN, 1001},
529 {LLAMA_GRETYPE_END, 0},
530 });
531
532 return 0;
533}