1#ifdef NDEBUG
  2#undef NDEBUG
  3#endif
  4
  5#include "llama.h"
  6
  7#include "../src/llama-grammar.h"
  8
  9#include <cassert>
 10#include <stdexcept>
 11
 12int main()
 13{
 14    llama_grammar_parser parsed_grammar;
 15
 16    std::vector<std::pair<std::string, uint32_t>> expected = {
 17        {"expr", 2},
 18        {"expr_6", 6},
 19        {"expr_7", 7},
 20        {"ident", 8},
 21        {"ident_10", 10},
 22        {"num", 9},
 23        {"num_11", 11},
 24        {"root", 0},
 25        {"root_1", 1},
 26        {"root_5", 5},
 27        {"term", 4},
 28        {"ws", 3},
 29        {"ws_12", 12},
 30    };
 31
 32    std::vector<std::vector<llama_grammar_element>> expected_rules = {
 33        {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
 34        {
 35            {LLAMA_GRETYPE_RULE_REF, 2},
 36            {LLAMA_GRETYPE_CHAR, 61},
 37            {LLAMA_GRETYPE_RULE_REF, 3},
 38            {LLAMA_GRETYPE_RULE_REF, 4},
 39            {LLAMA_GRETYPE_CHAR, 10},
 40            {LLAMA_GRETYPE_END, 0},
 41        },
 42        {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
 43        {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
 44        {
 45            {LLAMA_GRETYPE_RULE_REF, 8},
 46            {LLAMA_GRETYPE_ALT, 0},
 47            {LLAMA_GRETYPE_RULE_REF, 9},
 48            {LLAMA_GRETYPE_ALT, 0},
 49            {LLAMA_GRETYPE_CHAR, 40},
 50            {LLAMA_GRETYPE_RULE_REF, 3},
 51            {LLAMA_GRETYPE_RULE_REF, 2},
 52            {LLAMA_GRETYPE_CHAR, 41},
 53            {LLAMA_GRETYPE_RULE_REF, 3},
 54            {LLAMA_GRETYPE_END, 0},
 55        },
 56        {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
 57        {
 58            {LLAMA_GRETYPE_CHAR, 45},
 59            {LLAMA_GRETYPE_CHAR_ALT, 43},
 60            {LLAMA_GRETYPE_CHAR_ALT, 42},
 61            {LLAMA_GRETYPE_CHAR_ALT, 47},
 62            {LLAMA_GRETYPE_RULE_REF, 4},
 63            {LLAMA_GRETYPE_END, 0},
 64        },
 65        {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
 66        {
 67            {LLAMA_GRETYPE_CHAR, 97},
 68            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
 69            {LLAMA_GRETYPE_RULE_REF, 10},
 70            {LLAMA_GRETYPE_RULE_REF, 3},
 71            {LLAMA_GRETYPE_END, 0},
 72        },
 73        {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
 74        {
 75            {LLAMA_GRETYPE_CHAR, 97},
 76            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
 77            {LLAMA_GRETYPE_CHAR_ALT, 48},
 78            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
 79            {LLAMA_GRETYPE_CHAR_ALT, 95},
 80            {LLAMA_GRETYPE_RULE_REF, 10},
 81            {LLAMA_GRETYPE_ALT, 0},
 82            {LLAMA_GRETYPE_END, 0},
 83        },
 84        {
 85            {LLAMA_GRETYPE_CHAR, 48},
 86            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
 87            {LLAMA_GRETYPE_RULE_REF, 11},
 88            {LLAMA_GRETYPE_ALT, 0},
 89            {LLAMA_GRETYPE_CHAR, 48},
 90            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
 91            {LLAMA_GRETYPE_END, 0},
 92        },
 93        {
 94            {LLAMA_GRETYPE_CHAR, 32},
 95            {LLAMA_GRETYPE_CHAR_ALT, 9},
 96            {LLAMA_GRETYPE_CHAR_ALT, 10},
 97            {LLAMA_GRETYPE_RULE_REF, 12},
 98            {LLAMA_GRETYPE_ALT, 0},
 99            {LLAMA_GRETYPE_END, 0},
100        },
101    };
102
103    for (auto pair : expected)
104    {
105        parsed_grammar.symbol_ids[pair.first] = pair.second;
106    }
107
108    for (auto rule : expected_rules)
109    {
110        parsed_grammar.rules.emplace_back();
111        for (auto element : rule)
112        {
113            parsed_grammar.rules.back().push_back(element);
114        }
115    }
116
117    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
118
119    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
120    if (grammar == nullptr) {
121        throw std::runtime_error("Failed to initialize llama_grammar");
122    }
123
124    std::vector<std::vector<llama_grammar_element>> expected_stacks = {
125        {
126            {LLAMA_GRETYPE_RULE_REF, 5},
127            {LLAMA_GRETYPE_CHAR, 61},
128            {LLAMA_GRETYPE_RULE_REF, 7},
129            {LLAMA_GRETYPE_CHAR, 97},
130        },
131        {
132            {LLAMA_GRETYPE_RULE_REF, 5},
133            {LLAMA_GRETYPE_CHAR, 61},
134            {LLAMA_GRETYPE_RULE_REF, 7},
135            {LLAMA_GRETYPE_RULE_REF, 3},
136            {LLAMA_GRETYPE_CHAR, 48},
137        },
138        {
139            {LLAMA_GRETYPE_RULE_REF, 5},
140            {LLAMA_GRETYPE_CHAR, 61},
141            {LLAMA_GRETYPE_RULE_REF, 7},
142            {LLAMA_GRETYPE_RULE_REF, 3},
143            {LLAMA_GRETYPE_CHAR, 48},
144        },
145        {
146            {LLAMA_GRETYPE_RULE_REF, 5},
147            {LLAMA_GRETYPE_CHAR, 61},
148            {LLAMA_GRETYPE_RULE_REF, 7},
149            {LLAMA_GRETYPE_CHAR, 40},
150        },
151        {
152            {LLAMA_GRETYPE_CHAR, 61},
153            {LLAMA_GRETYPE_RULE_REF, 7},
154            {LLAMA_GRETYPE_CHAR, 97},
155        },
156        {
157            {LLAMA_GRETYPE_CHAR, 61},
158            {LLAMA_GRETYPE_RULE_REF, 7},
159            {LLAMA_GRETYPE_RULE_REF, 3},
160            {LLAMA_GRETYPE_CHAR, 48},
161        },
162        {
163            {LLAMA_GRETYPE_CHAR, 61},
164            {LLAMA_GRETYPE_RULE_REF, 7},
165            {LLAMA_GRETYPE_RULE_REF, 3},
166            {LLAMA_GRETYPE_CHAR, 48},
167        },
168        {
169            {LLAMA_GRETYPE_CHAR, 61},
170            {LLAMA_GRETYPE_RULE_REF, 7},
171            {LLAMA_GRETYPE_CHAR, 40},
172        }};
173
174    auto index = 0;
175    for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
176    {
177        // compare stack to expected_stack
178        for (uint32_t i = 0; i < stack.size(); i++)
179        {
180            const llama_grammar_element * element = stack[i];
181            const llama_grammar_element & expected_element = expected_stacks[index][i];
182
183            // pretty print error message before asserting
184            if (expected_element.type != element->type || expected_element.value != element->value)
185            {
186                fprintf(stderr, "index: %d\n", index);
187                fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
188                fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
189                fprintf(stderr, "expected_element != actual_element\n");
190            }
191
192            assert(expected_element.type == element->type && expected_element.value == element->value);
193        }
194        index++;
195    }
196
197    std::vector<llama_grammar_candidate> next_candidates;
198    next_candidates.resize(24);
199
200    for (size_t i = 0; i < 24; ++i)
201    {
202        uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
203        cp[0] = 37 + i;
204        cp[1] = 0;
205        next_candidates[i] = {i, cp, {}, 0};
206    }
207
208    std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
209        {
210            {0, 37},
211            {1, 38},
212            {2, 39},
213            {3, 40},
214            {4, 41},
215            {5, 42},
216            {6, 43},
217            {7, 44},
218            {8, 45},
219            {9, 46},
220            {10, 47},
221            {11, 48},
222            {12, 49},
223            {13, 50},
224            {14, 51},
225            {15, 52},
226            {16, 53},
227            {17, 54},
228            {18, 55},
229            {19, 56},
230            {20, 57},
231            {21, 58},
232            {22, 59},
233            {23, 60},
234        },
235        {
236            {0, 37},
237            {1, 38},
238            {2, 39},
239            {3, 40},
240            {4, 41},
241            {5, 42},
242            {6, 43},
243            {7, 44},
244            {8, 45},
245            {9, 46},
246            {10, 47},
247            {21, 58},
248            {22, 59},
249            {23, 60},
250        },
251        {
252            {0, 37},
253            {1, 38},
254            {2, 39},
255            {3, 40},
256            {4, 41},
257            {5, 42},
258            {6, 43},
259            {7, 44},
260            {8, 45},
261            {9, 46},
262            {10, 47},
263            {21, 58},
264            {22, 59},
265            {23, 60},
266        },
267        {
268            {0, 37},
269            {1, 38},
270            {2, 39},
271            {4, 41},
272            {5, 42},
273            {6, 43},
274            {7, 44},
275            {8, 45},
276            {9, 46},
277            {10, 47},
278            {11, 48},
279            {12, 49},
280            {13, 50},
281            {14, 51},
282            {15, 52},
283            {16, 53},
284            {17, 54},
285            {18, 55},
286            {19, 56},
287            {20, 57},
288            {21, 58},
289            {22, 59},
290            {23, 60},
291        },
292        {
293            {0, 37},
294            {1, 38},
295            {2, 39},
296            {3, 40},
297            {4, 41},
298            {5, 42},
299            {6, 43},
300            {7, 44},
301            {8, 45},
302            {9, 46},
303            {10, 47},
304            {11, 48},
305            {12, 49},
306            {13, 50},
307            {14, 51},
308            {15, 52},
309            {16, 53},
310            {17, 54},
311            {18, 55},
312            {19, 56},
313            {20, 57},
314            {21, 58},
315            {22, 59},
316            {23, 60},
317        },
318        {
319            {0, 37},
320            {1, 38},
321            {2, 39},
322            {3, 40},
323            {4, 41},
324            {5, 42},
325            {6, 43},
326            {7, 44},
327            {8, 45},
328            {9, 46},
329            {10, 47},
330            {21, 58},
331            {22, 59},
332            {23, 60},
333        },
334        {
335            {0, 37},
336            {1, 38},
337            {2, 39},
338            {3, 40},
339            {4, 41},
340            {5, 42},
341            {6, 43},
342            {7, 44},
343            {8, 45},
344            {9, 46},
345            {10, 47},
346            {21, 58},
347            {22, 59},
348            {23, 60},
349        },
350        {
351            {0, 37},
352            {1, 38},
353            {2, 39},
354            {4, 41},
355            {5, 42},
356            {6, 43},
357            {7, 44},
358            {8, 45},
359            {9, 46},
360            {10, 47},
361            {11, 48},
362            {12, 49},
363            {13, 50},
364            {14, 51},
365            {15, 52},
366            {16, 53},
367            {17, 54},
368            {18, 55},
369            {19, 56},
370            {20, 57},
371            {21, 58},
372            {22, 59},
373            {23, 60},
374        },
375    };
376
377    std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
378
379    std::vector<std::vector<llama_grammar_candidate>> all_rejects;
380
381    for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
382    {
383        rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
384        all_rejects.push_back(rejects);
385    }
386
387    index = 0;
388    for (auto rej : all_rejects)
389    {
390        for (uint32_t i = 0; i < rej.size(); i++)
391        {
392            auto element = rej[i];
393            auto expected_element = expected_reject[index][i];
394            assert(element.index == expected_element.first && *element.code_points == expected_element.second);
395        }
396        index++;
397    }
398
399    for (auto &candidate : next_candidates)
400    {
401        delete[] candidate.code_points;
402        candidate.code_points = nullptr;
403    }
404
405    llama_grammar_free_impl(grammar);
406
407    return 0;
408}