1#ifdef NDEBUG
  2#undef NDEBUG
  3#endif
  4
  5#include "llama.h"
  6
// TODO: should not include libllama sources
  8#include "../src/llama-grammar.h"
  9
 10#include <cassert>
 11
 12static const char * type_str(llama_gretype type) {
 13    switch (type) {
 14        case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR";
 15        case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT";
 16        case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT";
 17        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER";
 18        case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF";
 19        case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT";
 20        case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END";
 21        default: return "?";
 22    }
 23}
 24
 25static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
 26    uint32_t index = 0;
 27    llama_grammar_parser parsed_grammar;
 28    parsed_grammar.parse(grammar_bytes);
 29
 30    std::map<uint32_t, std::string> symbol_names;
 31    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
 32        symbol_names[it->second] = it->first;
 33    }
 34
 35    auto print_all = [&]() {
 36        fprintf(stderr, "    verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes);
 37        for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
 38            fprintf(stderr, "        {\"%s\", %u},\n", it->first.c_str(), it->second);
 39        }
 40        fprintf(stderr, "    }, {\n");
 41        for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) {
 42            fprintf(stderr, "        // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule);
 43            auto & rule = parsed_grammar.rules[i_rule];
 44            for (uint32_t i = 0; i < rule.size(); i++) {
 45                std::string rule_str;
 46                fprintf(stderr, "        {%s, ", type_str(rule[i].type));
 47                if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT ||
 48                    rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
 49                    char c = rule[i].value;
 50                    if (c == '\n') {
 51                        fprintf(stderr, "'\\n'");
 52                    } else if (c == '\t') {
 53                        fprintf(stderr, "'\\t'");
 54                    } else if (c == '\r') {
 55                        fprintf(stderr, "'\\r'");
 56                    } else if (c == '\0') {
 57                        fprintf(stderr, "'\\0'");
 58                    } else {
 59                        fprintf(stderr, "'%c'", c);
 60                    }
 61                } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
 62                    fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value);
 63                } else {
 64                    fprintf(stderr, "%u", rule[i].value);
 65                }
 66                fprintf(stderr, "},\n");
 67            }
 68        }
 69        fprintf(stderr, "    });\n");
 70    };
 71
 72    if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) {
 73        print_all();
 74        fprintf(stderr, "\n");
 75        return;
 76    }
 77
 78    fprintf(stderr, "Testing grammar:%s\n", grammar_bytes);
 79
 80    if (parsed_grammar.symbol_ids.size() != expected.size()) {
 81        fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
 82        print_all();
 83        assert(parsed_grammar.symbol_ids.size() == expected.size());
 84    }
 85
 86    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
 87    {
 88        std::string key = it->first;
 89        uint32_t value = it->second;
 90        std::pair<std::string, uint32_t> expected_pair = expected[index];
 91
 92        // pretty print error message before asserting
 93        if (expected_pair.first != key || expected_pair.second != value)
 94        {
 95            fprintf(stderr, "index: %u\n", index);
 96            fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
 97            fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
 98            fprintf(stderr, "expected_pair != actual_pair\n");
 99            fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
100            print_all();
101        }
102
103        assert(expected_pair.first == key && expected_pair.second == value);
104
105        index++;
106    }
107
108    index = 0;
109    for (auto rule : parsed_grammar.rules)
110    {
111        // compare rule to expected rule
112        for (uint32_t i = 0; i < rule.size(); i++)
113        {
114            llama_grammar_element element = rule[i];
115            llama_grammar_element expected_element = expected_rules[index];
116
117            // pretty print error message before asserting
118            if (expected_element.type != element.type || expected_element.value != element.value)
119            {
120                fprintf(stderr, "index: %u\n", index);
121                fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value);
122                fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value);
123                fprintf(stderr, "expected_element != actual_element\n");
124                fprintf(stderr, "all elements:\n");
125                fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
126                print_all();
127            }
128
129            assert(expected_element.type == element.type && expected_element.value == element.value);
130            index++;
131        }
132    }
133}
134
135static void verify_failure(const char * grammar_bytes) {
136    fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
137    llama_grammar_parser result;
138    result.parse(grammar_bytes);
139    assert(result.rules.empty() && "should have failed");
140}
141
142int main()
143{
144    verify_failure(R"""(
145        root ::= "a"{,}"
146    )""");
147
148    verify_failure(R"""(
149        root ::= "a"{,10}"
150    )""");
151
152    verify_parsing(R"""(
153        root  ::= "a"
154    )""", {
155        {"root", 0},
156    }, {
157        // root (index 0)
158        {LLAMA_GRETYPE_CHAR, 'a'},
159        {LLAMA_GRETYPE_END, 0},
160    });
161
162    verify_parsing(R"""(
163        root  ::= "a" | [bdx-z] | [^1-3]
164    )""", {
165        {"root", 0},
166    }, {
167        // root (index 0)
168        {LLAMA_GRETYPE_CHAR, 'a'},
169        {LLAMA_GRETYPE_ALT, 0},
170        {LLAMA_GRETYPE_CHAR, 'b'},
171        {LLAMA_GRETYPE_CHAR_ALT, 'd'},
172        {LLAMA_GRETYPE_CHAR_ALT, 'x'},
173        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
174        {LLAMA_GRETYPE_ALT, 0},
175        {LLAMA_GRETYPE_CHAR_NOT, '1'},
176        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'},
177        {LLAMA_GRETYPE_END, 0},
178    });
179
180    verify_parsing(R"""(
181        root  ::= a+
182        a     ::= "a"
183    )""", {
184        {"a", 1},
185        {"root", 0},
186        {"root_2", 2},
187    }, {
188        // root (index 0)
189        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
190        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
191        {LLAMA_GRETYPE_END, 0},
192        // a (index 1)
193        {LLAMA_GRETYPE_CHAR, 'a'},
194        {LLAMA_GRETYPE_END, 0},
195        // root_2 (index 2)
196        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
197        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
198        {LLAMA_GRETYPE_ALT, 0},
199        {LLAMA_GRETYPE_END, 0},
200    });
201
202    verify_parsing(R"""(
203        root  ::= "a"+
204    )""", {
205        {"root", 0},
206        {"root_1", 1},
207    }, {
208        // root (index 0)
209        {LLAMA_GRETYPE_CHAR, 'a'},
210        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
211        {LLAMA_GRETYPE_END, 0},
212        // root_1 (index 1)
213        {LLAMA_GRETYPE_CHAR, 'a'},
214        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
215        {LLAMA_GRETYPE_ALT, 0},
216        {LLAMA_GRETYPE_END, 0},
217    });
218
219    verify_parsing(R"""(
220        root  ::= a?
221        a     ::= "a"
222    )""", {
223        {"a", 1},
224        {"root", 0},
225        {"root_2", 2},
226    }, {
227        // root (index 0)
228        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
229        {LLAMA_GRETYPE_END, 0},
230        // a (index 1)
231        {LLAMA_GRETYPE_CHAR, 'a'},
232        {LLAMA_GRETYPE_END, 0},
233        // root_2 (index 2)
234        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
235        {LLAMA_GRETYPE_ALT, 0},
236        {LLAMA_GRETYPE_END, 0},
237    });
238
239    verify_parsing(R"""(
240        root  ::= "a"?
241    )""", {
242        {"root", 0},
243        {"root_1", 1},
244    }, {
245        // root (index 0)
246        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
247        {LLAMA_GRETYPE_END, 0},
248        // root_1 (index 1)
249        {LLAMA_GRETYPE_CHAR, 'a'},
250        {LLAMA_GRETYPE_ALT, 0},
251        {LLAMA_GRETYPE_END, 0},
252    });
253
254    verify_parsing(R"""(
255        root  ::= a*
256        a     ::= "a"
257    )""", {
258        {"a", 1},
259        {"root", 0},
260        {"root_2", 2},
261    }, {
262        // root (index 0)
263        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
264        {LLAMA_GRETYPE_END, 0},
265        // a (index 1)
266        {LLAMA_GRETYPE_CHAR, 'a'},
267        {LLAMA_GRETYPE_END, 0},
268        // root_2 (index 2)
269        {LLAMA_GRETYPE_RULE_REF, /* a */ 1},
270        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
271        {LLAMA_GRETYPE_ALT, 0},
272        {LLAMA_GRETYPE_END, 0},
273    });
274
275    verify_parsing(R"""(
276        root  ::= "a"*
277    )""", {
278        {"root", 0},
279        {"root_1", 1},
280    }, {
281        // root (index 0)
282        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
283        {LLAMA_GRETYPE_END, 0},
284        // root_1 (index 1)
285        {LLAMA_GRETYPE_CHAR, 'a'},
286        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
287        {LLAMA_GRETYPE_ALT, 0},
288        {LLAMA_GRETYPE_END, 0},
289    });
290
291    verify_parsing(R"""(
292        root  ::= "a"{2}
293    )""", {
294        {"root", 0},
295    }, {
296        // root (index 0)
297        {LLAMA_GRETYPE_CHAR, 'a'},
298        {LLAMA_GRETYPE_CHAR, 'a'},
299        {LLAMA_GRETYPE_END, 0},
300    });
301
302    verify_parsing(R"""(
303        root  ::= "a"{2,}
304    )""", {
305        {"root", 0},
306        {"root_1", 1},
307    }, {
308        // root (index 0)
309        {LLAMA_GRETYPE_CHAR, 'a'},
310        {LLAMA_GRETYPE_CHAR, 'a'},
311        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
312        {LLAMA_GRETYPE_END, 0},
313        // root_1 (index 1)
314        {LLAMA_GRETYPE_CHAR, 'a'},
315        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
316        {LLAMA_GRETYPE_ALT, 0},
317        {LLAMA_GRETYPE_END, 0},
318    });
319
320    verify_parsing(R"""(
321        root  ::= "a"{ 4}
322    )""", {
323        {"root", 0},
324    }, {
325        // root (index 0)
326        {LLAMA_GRETYPE_CHAR, 'a'},
327        {LLAMA_GRETYPE_CHAR, 'a'},
328        {LLAMA_GRETYPE_CHAR, 'a'},
329        {LLAMA_GRETYPE_CHAR, 'a'},
330        {LLAMA_GRETYPE_END, 0},
331    });
332
333    verify_parsing(R"""(
334        root  ::= "a"{2,4}
335    )""", {
336        {"root", 0},
337        {"root_1", 1},
338        {"root_2", 2},
339    }, {
340        // root (index 0)
341        {LLAMA_GRETYPE_CHAR, 'a'},
342        {LLAMA_GRETYPE_CHAR, 'a'},
343        {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2},
344        {LLAMA_GRETYPE_END, 0},
345        // root_1 (index 1)
346        {LLAMA_GRETYPE_CHAR, 'a'},
347        {LLAMA_GRETYPE_ALT, 0},
348        {LLAMA_GRETYPE_END, 0},
349        // root_2 (index 2)
350        {LLAMA_GRETYPE_CHAR, 'a'},
351        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
352        {LLAMA_GRETYPE_ALT, 0},
353        {LLAMA_GRETYPE_END, 0},
354    });
355
356    verify_parsing(R"""(
357        root  ::= (expr "=" term "\n")+
358        expr  ::= term ([-+*/] term)*
359        term  ::= [0-9]+
360    )""", {
361        {"expr", 2},
362        {"expr_5", 5},
363        {"expr_6", 6},
364        {"root", 0},
365        {"root_1", 1},
366        {"root_4", 4},
367        {"term", 3},
368        {"term_7", 7},
369    }, {
370        // root (index 0)
371        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
372        {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
373        {LLAMA_GRETYPE_END, 0},
374        // root_1 (index 1)
375        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
376        {LLAMA_GRETYPE_CHAR, '='},
377        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
378        {LLAMA_GRETYPE_CHAR, '\n'},
379        {LLAMA_GRETYPE_END, 0},
380        // expr (index 2)
381        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
382        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
383        {LLAMA_GRETYPE_END, 0},
384        // term (index 3)
385        {LLAMA_GRETYPE_CHAR, '0'},
386        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
387        {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
388        {LLAMA_GRETYPE_END, 0},
389        // root_4 (index 4)
390        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
391        {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4},
392        {LLAMA_GRETYPE_ALT, 0},
393        {LLAMA_GRETYPE_END, 0},
394        // expr_5 (index 5)
395        {LLAMA_GRETYPE_CHAR, '-'},
396        {LLAMA_GRETYPE_CHAR_ALT, '+'},
397        {LLAMA_GRETYPE_CHAR_ALT, '*'},
398        {LLAMA_GRETYPE_CHAR_ALT, '/'},
399        {LLAMA_GRETYPE_RULE_REF, /* term */ 3},
400        {LLAMA_GRETYPE_END, 0},
401        // expr_6 (index 6)
402        {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5},
403        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
404        {LLAMA_GRETYPE_ALT, 0},
405        {LLAMA_GRETYPE_END, 0},
406        // term_7 (index 7)
407        {LLAMA_GRETYPE_CHAR, '0'},
408        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
409        {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7},
410        {LLAMA_GRETYPE_ALT, 0},
411        {LLAMA_GRETYPE_END, 0},
412    });
413
414    verify_parsing(R"""(
415        root  ::= (expr "=" ws term "\n")+
416        expr  ::= term ([-+*/] term)*
417        term  ::= ident | num | "(" ws expr ")" ws
418        ident ::= [a-z] [a-z0-9_]* ws
419        num   ::= [0-9]+ ws
420        ws    ::= [ \t\n]*
421    )""", {
422        {"expr", 2},
423        {"expr_6", 6},
424        {"expr_7", 7},
425        {"ident", 8},
426        {"ident_10", 10},
427        {"num", 9},
428        {"num_11", 11},
429        {"root", 0},
430        {"root_1", 1},
431        {"root_5", 5},
432        {"term", 4},
433        {"ws", 3},
434        {"ws_12", 12},
435    }, {
436        // root (index 0)
437        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
438        {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
439        {LLAMA_GRETYPE_END, 0},
440        // root_1 (index 1)
441        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
442        {LLAMA_GRETYPE_CHAR, '='},
443        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
444        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
445        {LLAMA_GRETYPE_CHAR, '\n'},
446        {LLAMA_GRETYPE_END, 0},
447        // expr (index 2)
448        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
449        {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
450        {LLAMA_GRETYPE_END, 0},
451        // ws (index 3)
452        {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
453        {LLAMA_GRETYPE_END, 0},
454        // term (index 4)
455        {LLAMA_GRETYPE_RULE_REF, /* ident */ 8},
456        {LLAMA_GRETYPE_ALT, 0},
457        {LLAMA_GRETYPE_RULE_REF, /* num */ 9},
458        {LLAMA_GRETYPE_ALT, 0},
459        {LLAMA_GRETYPE_CHAR, '('},
460        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
461        {LLAMA_GRETYPE_RULE_REF, /* expr */ 2},
462        {LLAMA_GRETYPE_CHAR, ')'},
463        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
464        {LLAMA_GRETYPE_END, 0},
465        // root_5 (index 5)
466        {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1},
467        {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5},
468        {LLAMA_GRETYPE_ALT, 0},
469        {LLAMA_GRETYPE_END, 0},
470        // expr_6 (index 6)
471        {LLAMA_GRETYPE_CHAR, '-'},
472        {LLAMA_GRETYPE_CHAR_ALT, '+'},
473        {LLAMA_GRETYPE_CHAR_ALT, '*'},
474        {LLAMA_GRETYPE_CHAR_ALT, '/'},
475        {LLAMA_GRETYPE_RULE_REF, /* term */ 4},
476        {LLAMA_GRETYPE_END, 0},
477        // expr_7 (index 7)
478        {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6},
479        {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7},
480        {LLAMA_GRETYPE_ALT, 0},
481        {LLAMA_GRETYPE_END, 0},
482        // ident (index 8)
483        {LLAMA_GRETYPE_CHAR, 'a'},
484        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
485        {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
486        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
487        {LLAMA_GRETYPE_END, 0},
488        // num (index 9)
489        {LLAMA_GRETYPE_CHAR, '0'},
490        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
491        {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
492        {LLAMA_GRETYPE_RULE_REF, /* ws */ 3},
493        {LLAMA_GRETYPE_END, 0},
494        // ident_10 (index 10)
495        {LLAMA_GRETYPE_CHAR, 'a'},
496        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'},
497        {LLAMA_GRETYPE_CHAR_ALT, '0'},
498        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
499        {LLAMA_GRETYPE_CHAR_ALT, '_'},
500        {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10},
501        {LLAMA_GRETYPE_ALT, 0},
502        {LLAMA_GRETYPE_END, 0},
503        // num_11 (index 11)
504        {LLAMA_GRETYPE_CHAR, '0'},
505        {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'},
506        {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11},
507        {LLAMA_GRETYPE_ALT, 0},
508        {LLAMA_GRETYPE_END, 0},
509        // ws_12 (index 12)
510        {LLAMA_GRETYPE_CHAR, ' '},
511        {LLAMA_GRETYPE_CHAR_ALT, '\t'},
512        {LLAMA_GRETYPE_CHAR_ALT, '\n'},
513        {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12},
514        {LLAMA_GRETYPE_ALT, 0},
515        {LLAMA_GRETYPE_END, 0},
516    });
517
518    // <[1000]> = "<think>"
519    // <[1001]> = "</think>"
520    verify_parsing(R"""(
521        root  ::= <[1000]> !<[1001]> <[1001]>
522    )""", {
523        {"root", 0}
524    }, {
525        // root (index 0)
526        {LLAMA_GRETYPE_TOKEN, 1000},
527        {LLAMA_GRETYPE_TOKEN_NOT, 1001},
528        {LLAMA_GRETYPE_TOKEN, 1001},
529        {LLAMA_GRETYPE_END, 0},
530    });
531
532    return 0;
533}