// Keep assert() active in every build configuration: the checks in this file are
// plain assert() calls, so NDEBUG must not be defined when <cassert> is included below.
#ifdef NDEBUG
# undef NDEBUG
#endif
4
5#include "sampling.h"
6
7#include <cassert>
8#include <string>
9#include <vector>
10
// Vocabulary used by every test helper in this file; main() assigns it from the loaded model.
static const llama_vocab * vocab;
12
13static bool match_string(const std::string & input, llama_sampler * grammar) {
14 llama_sampler_reset(grammar);
15 auto tokens = common_tokenize(vocab, input, false, false);
16
17 auto n_vocab = llama_vocab_n_tokens(vocab);
18
19 std::vector<llama_token_data> cur;
20 cur.reserve(n_vocab);
21 for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
22 cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
23 }
24 auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
25
26 for (const auto token : tokens) {
27 for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
28 cur[token_id].logit = 0.0f;
29 }
30 llama_sampler_apply(grammar, &tok_arr);
31 if (cur[token].logit < 0.0f) {
32 return false;
33 }
34 llama_sampler_accept(grammar, token);
35 }
36
37 // do we allow EOS at the end? if so the grammar is accepting
38
39 auto tok_eos = llama_vocab_eot(vocab);
40 if (tok_eos == LLAMA_TOKEN_NULL) {
41 tok_eos = llama_vocab_eos(vocab);
42 }
43
44 cur[tok_eos].logit = 0.0f;
45 llama_sampler_apply(grammar, &tok_arr);
46
47 return cur[tok_eos].logit >= 0.0f;
48}
49
50static void test(const std::string & test_desc, const std::string & grammar_str,
51 const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
52 fprintf(stderr, "โซ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
53 fflush(stderr);
54
55 auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
56
57 fprintf(stderr, " ๐ต Valid strings:\n");
58
59 // Passing strings
60 for (const auto & test_string : passing_strings) {
61 fprintf(stderr, " \"%s\" ", test_string.c_str());
62 fflush(stderr);
63
64 bool matched = match_string(test_string, grammar);
65
66 if (!matched) {
67 fprintf(stderr, "โ (failed to match)\n");
68
69 // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
70 // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
71 FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
72 if (grammar_file) {
73 fprintf(grammar_file, "%s", grammar_str.c_str());
74 fclose(grammar_file);
75 }
76
77 // DEBUG: Write the test string to test-grammar-integration.string.txt
78 FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
79 if (string_file) {
80 fprintf(string_file, "%s", test_string.c_str());
81 fclose(string_file);
82 }
83
84 fprintf(stderr,
85 "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
86 "command: ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
87 "test-grammar-integration.string.txt\n\n");
88 } else {
89 fprintf(stdout, "โ
๏ธ\n");
90 }
91
92 assert(matched);
93 }
94
95 fprintf(stderr, " ๐ Invalid strings:\n");
96
97 // Failing strings
98 for (const auto & test_string : failing_strings) {
99 fprintf(stderr, " \"%s\" ", test_string.c_str());
100 fflush(stderr);
101
102 bool matched = match_string(test_string, grammar);
103
104 if (matched) {
105 fprintf(stderr, "โ (incorrectly matched)\n");
106 } else {
107 fprintf(stdout, "โ
๏ธ\n");
108 }
109 assert(!matched);
110 }
111
112 llama_sampler_free(grammar);
113}
114
115static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
116 const std::vector<std::string> & passing_strings,
117 const std::vector<std::string> & failing_strings) {
118 test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
119}
120
121static void test_schema(const std::string & test_desc, const std::string & schema_str,
122 const std::vector<std::string> & passing_strings,
123 const std::vector<std::string> & failing_strings) {
124 test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
125 failing_strings);
126}
127
128static void test_simple_grammar() {
129 test_schema("min 0",
130 R"""({
131 "type": "integer",
132 "minimum": 0
133 })""",
134 // Passing strings
135 {
136 "0",
137 "10",
138 "12",
139 "10000",
140 },
141 // Failing strings
142 {
143 "-1",
144 "-10",
145 "-10000",
146 "-100000000000000000000000000000000",
147 // "100000000000000000000000000000000",
148 "00",
149 "01",
150 "-0",
151 });
152 test_schema("min 2",
153 // Schema
154 R"""({
155 "type": "integer",
156 "minimum": 2
157 })""",
158 // Passing strings
159 {
160 "2",
161 "3",
162 "4",
163 "10",
164 "20",
165 "1234567890000000",
166 },
167 // Failing strings
168 {
169 "0", "1", "-1", "-100", "0", "1", "01", "02",
170 // "12345678900000000",
171 });
172 test_schema("min 456",
173 R"""({
174 "type": "integer",
175 "minimum": 456
176 })""",
177 // Passing strings
178 {
179 "456",
180 "4560",
181 "457",
182 "460",
183 "500",
184 },
185 // Failing strings
186 {
187 "455",
188 "356",
189 "50",
190 "050",
191 "-1",
192 "-456",
193 });
194 test_schema("min -123",
195 R"""({
196 "type": "integer",
197 "minimum": -123
198 })""",
199 // Passing strings
200 {
201 "-123",
202 "-122",
203 "-11",
204 "-1",
205 "0",
206 "1",
207 "123",
208 "1234",
209 "2345",
210 },
211 // Failing strings
212 {
213 "-1234",
214 "-124",
215 });
216
217 test_schema("max 9999",
218 // Schema
219 R"""({
220 "type": "integer",
221 "maximum": 9999
222 })""",
223 // Passing strings
224 {
225 "-99999",
226 "0",
227 "9999",
228 },
229 // Failing strings
230 {
231 "10000",
232 "99991",
233 });
234 test_schema("max -9999",
235 // Schema
236 R"""({
237 "type": "integer",
238 "maximum": -9999
239 })""",
240 // Passing strings
241 {
242 "-10000",
243 "-9999",
244 },
245 // Failing strings
246 {
247 "-9998",
248 "0",
249 "9999",
250 });
251 test_schema("min 5 max 30",
252 // Schema
253 R"""({
254 "type": "integer",
255 "minimum": 5,
256 "maximum": 30
257 })""",
258 // Passing strings
259 {
260 "5",
261 "10",
262 "30",
263 },
264 // Failing strings
265 {
266 "05",
267 "4",
268 "-1",
269 "31",
270 "123",
271 "0123",
272 });
273 test_schema("min -1 max 1",
274 R"""({
275 "type": "integer",
276 "minimum": -1,
277 "maximum": 1
278 })""",
279 // Passing strings
280 {
281 "-1",
282 "0",
283 "1",
284 },
285 // Failing strings
286 {
287 "-11",
288 "-10",
289 "-2",
290 "2",
291 "10",
292 "11",
293 });
294 test_schema("min -123 max 42",
295 R"""({
296 "type": "integer",
297 "minimum": -123,
298 "maximum": 42
299 })""",
300 // Passing strings
301 {
302 "-123",
303 "-122",
304 "-13",
305 "-11",
306 "-2",
307 "-1",
308 "0",
309 "1",
310 "5",
311 "10",
312 "39",
313 "40",
314 "42",
315 },
316 // Failing strings
317 {
318 "-0123",
319 "-124",
320 "-1123",
321 "-200",
322 "43",
323 "123",
324 "0123",
325 });
326 test_schema("exclusive min / max",
327 // Schema
328 R"""({
329 "type": "integer",
330 "exclusiveMinimum": 0,
331 "exclusiveMaximum": 10000
332 })""",
333 // Passing strings
334 {
335 "1",
336 "9999",
337 },
338 // Failing strings
339 {
340 "0",
341 "01",
342 "10000",
343 "99999",
344 });
345
346 // Test case for a simple grammar
347 test_grammar("simple grammar",
348 R"""(
349 start: expr
350 expr: term ("+" term)*
351 term: number
352 number: /[0-9]+/ )""",
353 // Passing strings
354 {
355 "42",
356 "1+2+3+4+5",
357 "123+456",
358 },
359 // Failing strings
360 {
361 "+",
362 "/ 3",
363 "1+2+3+4+5+",
364 "12a45",
365 });
366}
367
368static void test_complex_grammar() {
369 // Test case for a more complex grammar, with both failure strings and success strings
370 test_grammar("medium complexity grammar",
371 // Grammar
372 R"""(
373 start: expression
374 expression: term ws (("+"|"-") ws term)*
375 term: factor ws (("*"|"/") ws factor)*
376 factor: number | variable | "(" expression ")" | function-call
377 number: /[0-9]+/
378 variable: /[a-zA-Z_][a-zA-Z0-9_]*/
379 function-call: variable ws "(" (expression ("," ws expression)*)? ")"
380 ws: /[ \t\n\r]?/ )""",
381 // Passing strings
382 { "42",
383 "1*2*3*4*5",
384 "x",
385 "x+10",
386 "x1+y2",
387 "(a+b)*(c-d)",
388 "func()",
389 "func(x,y+2)",
390 "a*(b+c)-d/e",
391 "f(g(x),h(y,z))",
392 "x + 10",
393 "x1 + y2",
394 "(a + b) * (c - d)",
395 "func()",
396 "func(x, y + 2)",
397 "a * (b + c) - d / e",
398 "f(g(x), h(y, z))",
399 "123+456",
400 "123*456*789-123/456+789*123",
401 "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
402 // Failing strings
403 {
404 "+",
405 "/ 3x",
406 "x + + y",
407 "a * / b",
408 "func(,)",
409 "func(x y)",
410 "(a + b",
411 "x + y)",
412 "a + b * (c - d",
413 "42 +",
414 "x +",
415 "x + 10 +",
416 "(a + b) * (c - d",
417 "func(",
418 "func(x, y + 2",
419 "a * (b + c) - d /",
420 "f(g(x), h(y, z)",
421 "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
422 });
423}
424
425static void test_special_chars() {
426 // A collection of tests to exercise special characters such as "."
427 test_grammar("special characters",
428 // Grammar
429 R"""(
430 start: /.../ "abc" /.../
431 )""",
432 // Passing strings
433 { "abcabcabc", "aaaabcccc",
434 // NOTE: Also ensures that multi-byte characters still count as a single character
435 "๐ต๐ โ
abcโ๐ ๐ต" },
436 // Failing strings
437 { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "๐ต๐ โ
โabcโโ
๐ ๐ต", "๐ต๐ abc๐ ๐ต" });
438}
439
440static void test_quantifiers() {
441 // A collection of tests to exercise * + and ? quantifiers
442
443 test_grammar(
444 "* quantifier",
445 // Grammar
446 R"""(start: "a"*)""",
447 // Passing strings
448 { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
449 // Failing strings
450 { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
451 test_grammar(
452 "+ quantifier",
453 // Grammar
454 R"""(start: "a"+)""",
455 // Passing strings
456 { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
457 // Failing strings
458 { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
459 test_grammar("? quantifier",
460 // Grammar
461 R"""(start: "a"?)""",
462 // Passing strings
463 { "", "a" },
464 // Failing strings
465 {
466 "b",
467 "ab",
468 "aa",
469 "ba",
470 });
471 test_grammar("mixed quantifiers",
472 // Grammar
473 R"""(
474 start: cons+ vowel* cons? (vowel cons)*
475 vowel: /[aeiouy]/
476 cons: /[bcdfghjklmnpqrstvwxyz]/
477 )""",
478 // Passing strings
479 {
480 "yes",
481 "no",
482 "noyes",
483 "crwth",
484 "four",
485 "bryyyy",
486 },
487 // Failing strings
488 {
489 "yess",
490 "yesno",
491 "forty",
492 "catyyy",
493 });
494 test_grammar("simple exact repetition",
495 // Grammar
496 R"""(
497 start: /[ab]{4}/
498 )""",
499 // Passing strings
500 {
501 "aaaa",
502 "bbbb",
503 "abab",
504 },
505 // Failing strings
506 {
507 "a",
508 "b",
509 "aaaaa",
510 });
511 test_grammar("simple min repetition",
512 // Grammar
513 R"""(
514 start: /[ab]{4,}/
515 )""",
516 // Passing strings
517 {
518 "aaaa",
519 "aaaaab",
520 "bbbb",
521 "ababab",
522 },
523 // Failing strings
524 {
525 "",
526 "aba",
527 });
528 test_grammar("simple max repetition",
529 // Grammar
530 R"""(
531 start: /[ab]{0,4}/
532 )""",
533 // Passing strings
534 {
535 "",
536 "a",
537 "aa",
538 "aaa",
539 "aaab",
540 },
541 // Failing strings
542 {
543 "aaaaa",
544 });
545 // test_grammar("min / max repetition",
546 // // Grammar
547 // R"""(
548 // start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
549 // )""",
550 // // Passing strings
551 // {
552 // "0xFF 0x12 0xAB",
553 // "0xFF 0x12 0xAB 0x00 0x00",
554 // },
555 // // Failing strings
556 // {
557 // "",
558 // "0xFF",
559 // "0xFF 0x12",
560 // "0xFF 0x12 0xAB 0x00 0x00 0x00",
561 // });
562}
563
564static void test_json_schema() {
565 // Note that this is similar to the regular grammar tests,
566 // but we convert each json schema to a grammar before parsing.
567 // Otherwise, this test structure is the same.
568
569 test_schema("empty schema (object)",
570 // Schema
571 R"""(
572 {"type":"object"}
573 )""",
574 // Passing strings
575 {
576 R"""({})""",
577 R"""({"foo": "bar"})""",
578 },
579 // Failing strings
580 {
581 "",
582 "[]",
583 "null",
584 R"""("")""",
585 "true",
586 });
587
588 test_schema(
589 "exotic formats (list)",
590 // Schema
591 R"""({
592 "items": [
593 { "format": "date" },
594 { "format": "uuid" },
595 { "format": "time" },
596 { "format": "date-time" }
597 ]
598 })""",
599 // Passing strings
600 {
601 // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
602 // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
603 R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
604 //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
605 //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
606 },
607 // Failing strings
608 {
609 R"""(["foo", "bar"])""",
610 R"""(["12345678-1234-1234-1234-1234567890ab"])""",
611 });
612
613 test_schema("string",
614 // Schema
615 R"""({
616 "type": "string"
617 })""",
618 // Passing strings
619 {
620 R"""("foo")""",
621 R"""("bar")""",
622 R"""("")""",
623 },
624 // Failing strings
625 {
626 R"""({})""",
627 R"""("foo": "bar")""",
628 });
629
630 test_schema("string w/ min length 1",
631 // Schema
632 R"""({
633 "type": "string",
634 "minLength": 1
635 })""",
636 // Passing strings
637 {
638 R"""("foo")""",
639 R"""("bar")""",
640 },
641 // Failing strings
642 {
643 R"""("")""",
644 R"""({})""",
645 R"""("foo": "bar")""",
646 });
647
648 test_schema("string w/ min length 3",
649 // Schema
650 R"""({
651 "type": "string",
652 "minLength": 3
653 })""",
654 // Passing strings
655 {
656 R"""("foo")""",
657 R"""("bar")""",
658 R"""("foobar")""",
659 },
660 // Failing strings
661 {
662 R"""("")""",
663 R"""("f")""",
664 R"""("fo")""",
665 });
666
667 test_schema("string w/ max length",
668 // Schema
669 R"""({
670 "type": "string",
671 "maxLength": 3
672 })""",
673 // Passing strings
674 {
675 R"""("foo")""",
676 R"""("bar")""",
677 R"""("")""",
678 R"""("f")""",
679 R"""("fo")""",
680 },
681 // Failing strings
682 {
683 R"""("foobar")""",
684 });
685
686 test_schema("string w/ min & max length",
687 // Schema
688 R"""({
689 "type": "string",
690 "minLength": 1,
691 "maxLength": 4
692 })""",
693 // Passing strings
694 {
695 R"""("foo")""",
696 R"""("bar")""",
697 R"""("f")""",
698 R"""("barf")""",
699 },
700 // Failing strings
701 {
702 R"""("")""",
703 R"""("barfo")""",
704 R"""("foobar")""",
705 });
706
707 test_schema("boolean",
708 // Schema
709 R"""({
710 "type": "boolean"
711 })""",
712 // Passing strings
713 {
714 "true",
715 "false",
716 },
717 // Failing strings
718 {
719 R"""("")""",
720 R"""("true")""",
721 R"""(True)""",
722 R"""(FALSE)""",
723 });
724
725 test_schema("integer",
726 // Schema
727 R"""({
728 "type": "integer"
729 })""",
730 // Passing strings
731 {
732 R"""(0)""",
733 R"""(12345)""",
734 R"""(1234567890123456)""",
735 },
736 // Failing strings
737 {
738 R"""()""",
739 R"""(01)""",
740 R"""(007)""",
741 R"""(12345678901234567 )""",
742 });
743
744 test_schema("string const",
745 // Schema
746 R"""({
747 "const": "foo"
748 })""",
749 // Passing strings
750 {
751 R"""("foo")""",
752 },
753 // Failing strings
754 {
755 R"""(foo)""",
756 R"""("bar")""",
757 });
758
759 test_schema("non-string const",
760 // Schema
761 R"""({
762 "const": true
763 })""",
764 // Passing strings
765 {
766 R"""(true)""",
767 },
768 // Failing strings
769 {
770 R"""()""",
771 R"""(foo)""",
772 R"""("true")""",
773 });
774
775 test_schema("non-string const",
776 // Schema
777 R"""({
778 "enum": ["red", "amber", "green", null, 42, ["foo"]]
779 })""",
780 // Passing strings
781 {
782 R"""("red")""",
783 R"""(null)""",
784 R"""(42)""",
785 R"""(["foo"])""",
786 },
787 // Failing strings
788 {
789 R"""()""",
790 R"""(420)""",
791 R"""(true)""",
792 R"""(foo)""",
793 });
794
795 test_schema("simple pattern",
796 // Schema
797 R"""({
798 "pattern": "^[a-zA-Z0-9_-]*$"
799 })""",
800 // Passing strings
801 {
802 R"""("")""",
803 R"""("He_llo-12")""",
804 },
805 // Failing strings
806 {
807 R"""("!")""",
808 R"""("Hello World")""",
809 });
810
811 test_schema("pattern with escapes",
812 // Schema
813 R"""({
814 "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
815 })""",
816 // Passing strings
817 {
818 R"""("a^$.[]()|{}*+?b")""",
819 },
820 // Failing strings
821 {
822 R"""("ab")""",
823 });
824
825 test_schema("",
826 // Schema
827 R"""(
828 {
829 "type": ["array", "null"],
830 "items": { "type": "string" }
831 }
832 )""",
833 // Passing strings
834 {
835 "null",
836 "[]",
837 "[\"123\"]",
838 "[\"foo\", \"bar\"]",
839 },
840 // Failing strings
841 {
842 "",
843 "[123]",
844 "\"foo\"",
845 "[\"foo\", 42]",
846 });
847
848 test_schema("min+max items",
849 // Schema
850 R"""({
851 "items": {
852 "type": ["number", "integer"]
853 },
854 "minItems": 3,
855 "maxItems": 5
856 })""",
857 // Passing strings
858 {
859 R"""([1, 2, 3])""",
860 R"""([1, 2, 3, 4])""",
861 R"""([1, 2, 3, 4, 5])""",
862 // this is in fact correct; keyword do not apply if the type is wrong
863 R"""(1)""",
864 },
865 // Failing strings
866 {
867 R"""([1, 2])""",
868 R"""([1, 2, 3, 4, 5, 6])""",
869 });
870
871 // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
872 test_schema("object properties",
873 // Schema
874 R"""({
875 "type": "object",
876 "properties": {
877 "number": { "type": "number" },
878 "street_name": { "type": "string" },
879 "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
880 },
881 "additionalProperties": false
882 })""",
883 // Passing strings
884 {
885 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
886 // "By default, leaving out properties is valid"
887 R"""({ "street_name": "Pennsylvania" })""",
888 R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
889 // "By extension, even an empty object is valid"
890 R"""({})""",
891 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
892 },
893 // Failing strings
894 {
895 // Change datatype from number to string
896 R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
897 // Reorder properties
898 R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
899 // Reorder properties
900 R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
901 // Additional properties set to false
902 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
903
904 });
905
906 test_schema("additional properties can't override other properties",
907 R"""({
908 "properties": {
909 "a": {"type": "integer"},
910 "b": {"type": "integer"}
911 },
912 "additionalProperties": true
913 })""",
914 // Passing strings
915 {
916 R"""({"a": 42})""",
917 R"""({"c": ""})""",
918 R"""({"a": 42, "c": ""})""",
919 R"""({"a_": ""})""",
920 },
921 // Failing strings
922 {
923 R"""()""",
924 R"""({"a": ""})""",
925 R"""({"a": "", "b": ""})""",
926 });
927
928 // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
929 test_schema("object properties, additionalProperties: true",
930 // Schema
931 R"""({
932 "type": "object",
933 "properties": {
934 "number": { "type": "number" },
935 "street_name": { "type": "string" },
936 "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
937 },
938 "additionalProperties": true
939 })""",
940 // Passing strings
941 {
942 // "By extension, even an empty object is valid"
943 R"""({})""",
944 R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
945 // "By default, leaving out properties is valid"
946 R"""({ "street_name": "Pennsylvania" })""",
947 R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
948 // "By default, providing additional properties is valid"
949 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
950 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
951 },
952 // Failing strings
953 {
954 // Change datatype from number to string
955 R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
956 // Reorder properties
957 R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
958 });
959
960 // Additional properties: false
961 test_schema(
962 "required + optional props each in original order",
963 // Schema
964 R"""({
965 "type": "object",
966 "properties": {
967 "number": { "type": "number" },
968 "street_name": { "type": "string" },
969 "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
970 },
971 "additionalProperties": false
972 })""",
973 // Passing strings
974 {
975 R"""({ "street_name": "Pennsylvania" })""",
976 R"""({ "number": 1600, "street_type":"Avenue"})""",
977 R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
978 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
979 // Spaces are permitted around enum values
980 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
981 },
982 // Failing strings
983 {
984 // Reorder properties
985 R"""({ "street_type": "Avenue", "number": 1600 })""",
986 // Add "direction"
987 R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
988 });
989
990 test_schema("required + optional props each in original order",
991 // Schema
992 R"""({
993 "properties": {
994 "b": {"type": "string"},
995 "a": {"type": "string"},
996 "d": {"type": "string"},
997 "c": {"type": "string"}
998 },
999 "required": ["a", "b"],
1000 "additionalProperties": false
1001 })""",
1002 // Passing strings
1003 {
1004 R"""({"b": "foo", "a": "bar"})""",
1005 R"""({"b":"foo","a":"bar","d":"qux"})""",
1006 R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
1007 },
1008 // Failing strings
1009 {
1010 R"""({"a": "foo", "b": "bar"})""",
1011 R"""({"b": "bar"})""",
1012 R"""({"a": "foo", "c": "baz"})""",
1013 R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
1014 });
1015
1016 // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
1017 test_schema(
1018 "required props",
1019 // Schema
1020 R"""({
1021 "$schema": "https://json-schema.org/draft/2020-12/schema",
1022 "$id": "https://example.com/product.schema.json",
1023 "title": "Product",
1024 "description": "A product from Acme's catalog",
1025 "type": "object",
1026 "properties": {
1027 "productId": {
1028 "description": "The unique identifier for a product",
1029 "type": "integer"
1030 },
1031 "productName": {
1032 "description": "Name of the product",
1033 "type": "string"
1034 },
1035 "price": {
1036 "description": "The price of the product",
1037 "type": "number",
1038 "exclusiveMinimum": 0
1039 },
1040 "tags": {
1041 "description": "Tags for the product",
1042 "type": "array",
1043 "items": {
1044 "type": "string"
1045 },
1046 "minItems": 1,
1047 "DISABLED_uniqueItems": true
1048 },
1049 "dimensions": {
1050 "type": "object",
1051 "properties": {
1052 "length": {
1053 "type": "number"
1054 },
1055 "width": {
1056 "type": "number"
1057 },
1058 "height": {
1059 "type": "number"
1060 }
1061 },
1062 "required": [ "length", "width", "height" ]
1063 }
1064 },
1065 "required": [ "productId", "productName", "price" ]
1066 })""",
1067 // Passing strings
1068 {
1069 R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
1070 R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
1071 R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
1072 },
1073 // Failing strings
1074 {
1075 R"""({})""", // Missing all required properties
1076 R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties
1077 // `exclusiveMinimum` is OK for llg
1078 R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
1079 R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price)
1080 R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId)
1081 R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1
1082 R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order
1083 // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
1084 // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
1085 });
1086}
1087
1088static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
1089 auto n_vocab = tok_arr.size;
1090
1091 tok_arr.selected = -1;
1092 tok_arr.sorted = false;
1093 for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1094 tok_arr.data[token_id].id = token_id;
1095 tok_arr.data[token_id].logit = 0.0f;
1096 }
1097
1098 tok_arr.data[selected].logit = 100.0f;
1099}
1100
1101static void test_sampler_chain(void) {
1102 auto sparams = llama_sampler_chain_default_params();
1103 sparams.no_perf = false;
1104 llama_sampler * sampler = llama_sampler_chain_init(sparams);
1105
1106 const auto grammar_data = R"(%llguidance {}
1107start: /[A-Z ]*/)";
1108
1109 llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
1110 llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
1111
1112 auto input = "ALL YOUR BASE ARE BELONG TO US";
1113 auto tokens = common_tokenize(vocab, input, false, false);
1114
1115 auto n_vocab = llama_vocab_n_tokens(vocab);
1116
1117 std::vector<llama_token_data> cur;
1118 cur.reserve(n_vocab);
1119 for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1120 cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
1121 }
1122 auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
1123
1124 for (const auto token : tokens) {
1125 one_hot(tok_arr, token);
1126
1127 fprintf(stderr, "applying token: %d\n", token);
1128 llama_sampler_apply(sampler, &tok_arr);
1129
1130 auto idx = tok_arr.selected;
1131 fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
1132 assert(cur[tok_arr.selected].id == token);
1133 llama_sampler_accept(sampler, token);
1134 }
1135
1136 auto tok_eos = llama_vocab_eot(vocab);
1137 if (tok_eos == LLAMA_TOKEN_NULL) {
1138 tok_eos = llama_vocab_eos(vocab);
1139 }
1140
1141 one_hot(tok_arr, tok_eos);
1142
1143 llama_sampler_apply(sampler, &tok_arr);
1144 assert(cur[tok_arr.selected].id == tok_eos);
1145}
1146
1147int main(int argc, const char ** argv) {
1148 fprintf(stdout, "Running llguidance integration tests...\n");
1149
1150 if (argc != 2) {
1151 fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
1152 return 1;
1153 }
1154
1155 const char * vocab_file = argv[1];
1156
1157 fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
1158
1159 llama_model * model;
1160 llama_context * ctx;
1161
1162 llama_backend_init();
1163
1164 // load the vocab
1165 {
1166 auto mparams = llama_model_default_params();
1167
1168 mparams.vocab_only = true;
1169
1170 model = llama_model_load_from_file(vocab_file, mparams);
1171
1172 if (model == NULL) {
1173 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
1174 return 1;
1175 }
1176
1177 // needed?
1178 auto cparams = llama_context_default_params();
1179
1180 ctx = llama_init_from_model(model, cparams);
1181
1182 if (ctx == NULL) {
1183 fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
1184 llama_model_free(model);
1185 return 1;
1186 }
1187 }
1188
1189 vocab = llama_model_get_vocab(model);
1190
1191 test_simple_grammar();
1192 test_complex_grammar();
1193 test_special_chars();
1194 test_quantifiers();
1195 test_json_schema();
1196
1197 test_sampler_chain();
1198
1199 llama_free(ctx);
1200 llama_model_free(model);
1201
1202 fprintf(stdout, "All tests passed.\n");
1203 return 0;
1204}