1#ifdef NDEBUG
2#undef NDEBUG
3#endif
4
5#include "llama.h"
6
7#include "../src/llama-grammar.h"
8
9#include <cassert>
10#include <stdexcept>
11
12int main()
13{
14 llama_grammar_parser parsed_grammar;
15
16 std::vector<std::pair<std::string, uint32_t>> expected = {
17 {"expr", 2},
18 {"expr_6", 6},
19 {"expr_7", 7},
20 {"ident", 8},
21 {"ident_10", 10},
22 {"num", 9},
23 {"num_11", 11},
24 {"root", 0},
25 {"root_1", 1},
26 {"root_5", 5},
27 {"term", 4},
28 {"ws", 3},
29 {"ws_12", 12},
30 };
31
32 std::vector<std::vector<llama_grammar_element>> expected_rules = {
33 {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
34 {
35 {LLAMA_GRETYPE_RULE_REF, 2},
36 {LLAMA_GRETYPE_CHAR, 61},
37 {LLAMA_GRETYPE_RULE_REF, 3},
38 {LLAMA_GRETYPE_RULE_REF, 4},
39 {LLAMA_GRETYPE_CHAR, 10},
40 {LLAMA_GRETYPE_END, 0},
41 },
42 {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
43 {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
44 {
45 {LLAMA_GRETYPE_RULE_REF, 8},
46 {LLAMA_GRETYPE_ALT, 0},
47 {LLAMA_GRETYPE_RULE_REF, 9},
48 {LLAMA_GRETYPE_ALT, 0},
49 {LLAMA_GRETYPE_CHAR, 40},
50 {LLAMA_GRETYPE_RULE_REF, 3},
51 {LLAMA_GRETYPE_RULE_REF, 2},
52 {LLAMA_GRETYPE_CHAR, 41},
53 {LLAMA_GRETYPE_RULE_REF, 3},
54 {LLAMA_GRETYPE_END, 0},
55 },
56 {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
57 {
58 {LLAMA_GRETYPE_CHAR, 45},
59 {LLAMA_GRETYPE_CHAR_ALT, 43},
60 {LLAMA_GRETYPE_CHAR_ALT, 42},
61 {LLAMA_GRETYPE_CHAR_ALT, 47},
62 {LLAMA_GRETYPE_RULE_REF, 4},
63 {LLAMA_GRETYPE_END, 0},
64 },
65 {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
66 {
67 {LLAMA_GRETYPE_CHAR, 97},
68 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
69 {LLAMA_GRETYPE_RULE_REF, 10},
70 {LLAMA_GRETYPE_RULE_REF, 3},
71 {LLAMA_GRETYPE_END, 0},
72 },
73 {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
74 {
75 {LLAMA_GRETYPE_CHAR, 97},
76 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
77 {LLAMA_GRETYPE_CHAR_ALT, 48},
78 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
79 {LLAMA_GRETYPE_CHAR_ALT, 95},
80 {LLAMA_GRETYPE_RULE_REF, 10},
81 {LLAMA_GRETYPE_ALT, 0},
82 {LLAMA_GRETYPE_END, 0},
83 },
84 {
85 {LLAMA_GRETYPE_CHAR, 48},
86 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
87 {LLAMA_GRETYPE_RULE_REF, 11},
88 {LLAMA_GRETYPE_ALT, 0},
89 {LLAMA_GRETYPE_CHAR, 48},
90 {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
91 {LLAMA_GRETYPE_END, 0},
92 },
93 {
94 {LLAMA_GRETYPE_CHAR, 32},
95 {LLAMA_GRETYPE_CHAR_ALT, 9},
96 {LLAMA_GRETYPE_CHAR_ALT, 10},
97 {LLAMA_GRETYPE_RULE_REF, 12},
98 {LLAMA_GRETYPE_ALT, 0},
99 {LLAMA_GRETYPE_END, 0},
100 },
101 };
102
103 for (auto pair : expected)
104 {
105 parsed_grammar.symbol_ids[pair.first] = pair.second;
106 }
107
108 for (auto rule : expected_rules)
109 {
110 parsed_grammar.rules.emplace_back();
111 for (auto element : rule)
112 {
113 parsed_grammar.rules.back().push_back(element);
114 }
115 }
116
117 std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
118
119 llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
120 if (grammar == nullptr) {
121 throw std::runtime_error("Failed to initialize llama_grammar");
122 }
123
124 std::vector<std::vector<llama_grammar_element>> expected_stacks = {
125 {
126 {LLAMA_GRETYPE_RULE_REF, 5},
127 {LLAMA_GRETYPE_CHAR, 61},
128 {LLAMA_GRETYPE_RULE_REF, 7},
129 {LLAMA_GRETYPE_CHAR, 97},
130 },
131 {
132 {LLAMA_GRETYPE_RULE_REF, 5},
133 {LLAMA_GRETYPE_CHAR, 61},
134 {LLAMA_GRETYPE_RULE_REF, 7},
135 {LLAMA_GRETYPE_RULE_REF, 3},
136 {LLAMA_GRETYPE_CHAR, 48},
137 },
138 {
139 {LLAMA_GRETYPE_RULE_REF, 5},
140 {LLAMA_GRETYPE_CHAR, 61},
141 {LLAMA_GRETYPE_RULE_REF, 7},
142 {LLAMA_GRETYPE_RULE_REF, 3},
143 {LLAMA_GRETYPE_CHAR, 48},
144 },
145 {
146 {LLAMA_GRETYPE_RULE_REF, 5},
147 {LLAMA_GRETYPE_CHAR, 61},
148 {LLAMA_GRETYPE_RULE_REF, 7},
149 {LLAMA_GRETYPE_CHAR, 40},
150 },
151 {
152 {LLAMA_GRETYPE_CHAR, 61},
153 {LLAMA_GRETYPE_RULE_REF, 7},
154 {LLAMA_GRETYPE_CHAR, 97},
155 },
156 {
157 {LLAMA_GRETYPE_CHAR, 61},
158 {LLAMA_GRETYPE_RULE_REF, 7},
159 {LLAMA_GRETYPE_RULE_REF, 3},
160 {LLAMA_GRETYPE_CHAR, 48},
161 },
162 {
163 {LLAMA_GRETYPE_CHAR, 61},
164 {LLAMA_GRETYPE_RULE_REF, 7},
165 {LLAMA_GRETYPE_RULE_REF, 3},
166 {LLAMA_GRETYPE_CHAR, 48},
167 },
168 {
169 {LLAMA_GRETYPE_CHAR, 61},
170 {LLAMA_GRETYPE_RULE_REF, 7},
171 {LLAMA_GRETYPE_CHAR, 40},
172 }};
173
174 auto index = 0;
175 for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
176 {
177 // compare stack to expected_stack
178 for (uint32_t i = 0; i < stack.size(); i++)
179 {
180 const llama_grammar_element * element = stack[i];
181 const llama_grammar_element & expected_element = expected_stacks[index][i];
182
183 // pretty print error message before asserting
184 if (expected_element.type != element->type || expected_element.value != element->value)
185 {
186 fprintf(stderr, "index: %d\n", index);
187 fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
188 fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
189 fprintf(stderr, "expected_element != actual_element\n");
190 }
191
192 assert(expected_element.type == element->type && expected_element.value == element->value);
193 }
194 index++;
195 }
196
197 std::vector<llama_grammar_candidate> next_candidates;
198 next_candidates.resize(24);
199
200 for (size_t i = 0; i < 24; ++i)
201 {
202 uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
203 cp[0] = 37 + i;
204 cp[1] = 0;
205 next_candidates[i] = {i, cp, {}, 0};
206 }
207
208 std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
209 {
210 {0, 37},
211 {1, 38},
212 {2, 39},
213 {3, 40},
214 {4, 41},
215 {5, 42},
216 {6, 43},
217 {7, 44},
218 {8, 45},
219 {9, 46},
220 {10, 47},
221 {11, 48},
222 {12, 49},
223 {13, 50},
224 {14, 51},
225 {15, 52},
226 {16, 53},
227 {17, 54},
228 {18, 55},
229 {19, 56},
230 {20, 57},
231 {21, 58},
232 {22, 59},
233 {23, 60},
234 },
235 {
236 {0, 37},
237 {1, 38},
238 {2, 39},
239 {3, 40},
240 {4, 41},
241 {5, 42},
242 {6, 43},
243 {7, 44},
244 {8, 45},
245 {9, 46},
246 {10, 47},
247 {21, 58},
248 {22, 59},
249 {23, 60},
250 },
251 {
252 {0, 37},
253 {1, 38},
254 {2, 39},
255 {3, 40},
256 {4, 41},
257 {5, 42},
258 {6, 43},
259 {7, 44},
260 {8, 45},
261 {9, 46},
262 {10, 47},
263 {21, 58},
264 {22, 59},
265 {23, 60},
266 },
267 {
268 {0, 37},
269 {1, 38},
270 {2, 39},
271 {4, 41},
272 {5, 42},
273 {6, 43},
274 {7, 44},
275 {8, 45},
276 {9, 46},
277 {10, 47},
278 {11, 48},
279 {12, 49},
280 {13, 50},
281 {14, 51},
282 {15, 52},
283 {16, 53},
284 {17, 54},
285 {18, 55},
286 {19, 56},
287 {20, 57},
288 {21, 58},
289 {22, 59},
290 {23, 60},
291 },
292 {
293 {0, 37},
294 {1, 38},
295 {2, 39},
296 {3, 40},
297 {4, 41},
298 {5, 42},
299 {6, 43},
300 {7, 44},
301 {8, 45},
302 {9, 46},
303 {10, 47},
304 {11, 48},
305 {12, 49},
306 {13, 50},
307 {14, 51},
308 {15, 52},
309 {16, 53},
310 {17, 54},
311 {18, 55},
312 {19, 56},
313 {20, 57},
314 {21, 58},
315 {22, 59},
316 {23, 60},
317 },
318 {
319 {0, 37},
320 {1, 38},
321 {2, 39},
322 {3, 40},
323 {4, 41},
324 {5, 42},
325 {6, 43},
326 {7, 44},
327 {8, 45},
328 {9, 46},
329 {10, 47},
330 {21, 58},
331 {22, 59},
332 {23, 60},
333 },
334 {
335 {0, 37},
336 {1, 38},
337 {2, 39},
338 {3, 40},
339 {4, 41},
340 {5, 42},
341 {6, 43},
342 {7, 44},
343 {8, 45},
344 {9, 46},
345 {10, 47},
346 {21, 58},
347 {22, 59},
348 {23, 60},
349 },
350 {
351 {0, 37},
352 {1, 38},
353 {2, 39},
354 {4, 41},
355 {5, 42},
356 {6, 43},
357 {7, 44},
358 {8, 45},
359 {9, 46},
360 {10, 47},
361 {11, 48},
362 {12, 49},
363 {13, 50},
364 {14, 51},
365 {15, 52},
366 {16, 53},
367 {17, 54},
368 {18, 55},
369 {19, 56},
370 {20, 57},
371 {21, 58},
372 {22, 59},
373 {23, 60},
374 },
375 };
376
377 std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
378
379 std::vector<std::vector<llama_grammar_candidate>> all_rejects;
380
381 for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
382 {
383 rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
384 all_rejects.push_back(rejects);
385 }
386
387 index = 0;
388 for (auto rej : all_rejects)
389 {
390 for (uint32_t i = 0; i < rej.size(); i++)
391 {
392 auto element = rej[i];
393 auto expected_element = expected_reject[index][i];
394 assert(element.index == expected_element.first && *element.code_points == expected_element.second);
395 }
396 index++;
397 }
398
399 for (auto &candidate : next_candidates)
400 {
401 delete[] candidate.code_points;
402 candidate.code_points = nullptr;
403 }
404
405 llama_grammar_free_impl(grammar);
406
407 return 0;
408}