// Keep assert() active even when the build defines NDEBUG: the checks in this file are
// expressed as assert() calls, and <cassert> is included after this point.
#ifdef NDEBUG
#    undef NDEBUG
#endif
   4
#include "sampling.h"

#include <cassert>
#include <cstdio>
#include <string>
#include <vector>
  10
  11static const llama_vocab * vocab;
  12
  13static bool match_string(const std::string & input, llama_sampler * grammar) {
  14    llama_sampler_reset(grammar);
  15    auto tokens = common_tokenize(vocab, input, false, false);
  16
  17    auto n_vocab = llama_vocab_n_tokens(vocab);
  18
  19    std::vector<llama_token_data> cur;
  20    cur.reserve(n_vocab);
  21    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  22        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
  23    }
  24    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
  25
  26    for (const auto token : tokens) {
  27        for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  28            cur[token_id].logit = 0.0f;
  29        }
  30        llama_sampler_apply(grammar, &tok_arr);
  31        if (cur[token].logit < 0.0f) {
  32            return false;
  33        }
  34        llama_sampler_accept(grammar, token);
  35    }
  36
  37    // do we allow EOS at the end? if so the grammar is accepting
  38
  39    auto tok_eos = llama_vocab_eot(vocab);
  40    if (tok_eos == LLAMA_TOKEN_NULL) {
  41        tok_eos = llama_vocab_eos(vocab);
  42    }
  43
  44    cur[tok_eos].logit = 0.0f;
  45    llama_sampler_apply(grammar, &tok_arr);
  46
  47    return cur[tok_eos].logit >= 0.0f;
  48}
  49
  50static void test(const std::string & test_desc, const std::string & grammar_str,
  51                 const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
  52    fprintf(stderr, "โšซ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
  53    fflush(stderr);
  54
  55    auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
  56
  57    fprintf(stderr, "  ๐Ÿ”ต Valid strings:\n");
  58
  59    // Passing strings
  60    for (const auto & test_string : passing_strings) {
  61        fprintf(stderr, "    \"%s\" ", test_string.c_str());
  62        fflush(stderr);
  63
  64        bool matched = match_string(test_string, grammar);
  65
  66        if (!matched) {
  67            fprintf(stderr, "โŒ (failed to match)\n");
  68
  69            // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
  70            // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
  71            FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
  72            if (grammar_file) {
  73                fprintf(grammar_file, "%s", grammar_str.c_str());
  74                fclose(grammar_file);
  75            }
  76
  77            // DEBUG: Write the test string to test-grammar-integration.string.txt
  78            FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
  79            if (string_file) {
  80                fprintf(string_file, "%s", test_string.c_str());
  81                fclose(string_file);
  82            }
  83
  84            fprintf(stderr,
  85                    "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
  86                    "command:     ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
  87                    "test-grammar-integration.string.txt\n\n");
  88        } else {
  89            fprintf(stdout, "โœ…๏ธŽ\n");
  90        }
  91
  92        assert(matched);
  93    }
  94
  95    fprintf(stderr, "  ๐ŸŸ  Invalid strings:\n");
  96
  97    // Failing strings
  98    for (const auto & test_string : failing_strings) {
  99        fprintf(stderr, "    \"%s\" ", test_string.c_str());
 100        fflush(stderr);
 101
 102        bool matched = match_string(test_string, grammar);
 103
 104        if (matched) {
 105            fprintf(stderr, "โŒ (incorrectly matched)\n");
 106        } else {
 107            fprintf(stdout, "โœ…๏ธŽ\n");
 108        }
 109        assert(!matched);
 110    }
 111
 112    llama_sampler_free(grammar);
 113}
 114
 115static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
 116                         const std::vector<std::string> & passing_strings,
 117                         const std::vector<std::string> & failing_strings) {
 118    test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
 119}
 120
 121static void test_schema(const std::string & test_desc, const std::string & schema_str,
 122                        const std::vector<std::string> & passing_strings,
 123                        const std::vector<std::string> & failing_strings) {
 124    test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
 125         failing_strings);
 126}
 127
 128static void test_simple_grammar() {
 129    test_schema("min 0",
 130                R"""({
 131            "type": "integer",
 132            "minimum": 0
 133        })""",
 134                // Passing strings
 135                {
 136                    "0",
 137                    "10",
 138                    "12",
 139                    "10000",
 140                },
 141                // Failing strings
 142                {
 143                    "-1",
 144                    "-10",
 145                    "-10000",
 146                    "-100000000000000000000000000000000",
 147                    // "100000000000000000000000000000000",
 148                    "00",
 149                    "01",
 150                    "-0",
 151                });
 152    test_schema("min 2",
 153                // Schema
 154                R"""({
 155            "type": "integer",
 156            "minimum": 2
 157        })""",
 158                // Passing strings
 159                {
 160                    "2",
 161                    "3",
 162                    "4",
 163                    "10",
 164                    "20",
 165                    "1234567890000000",
 166                },
 167                // Failing strings
 168                {
 169                    "0", "1", "-1", "-100", "0", "1", "01", "02",
 170                    // "12345678900000000",
 171                });
 172    test_schema("min 456",
 173                R"""({
 174            "type": "integer",
 175            "minimum": 456
 176        })""",
 177                // Passing strings
 178                {
 179                    "456",
 180                    "4560",
 181                    "457",
 182                    "460",
 183                    "500",
 184                },
 185                // Failing strings
 186                {
 187                    "455",
 188                    "356",
 189                    "50",
 190                    "050",
 191                    "-1",
 192                    "-456",
 193                });
 194    test_schema("min -123",
 195                R"""({
 196            "type": "integer",
 197            "minimum": -123
 198        })""",
 199                // Passing strings
 200                {
 201                    "-123",
 202                    "-122",
 203                    "-11",
 204                    "-1",
 205                    "0",
 206                    "1",
 207                    "123",
 208                    "1234",
 209                    "2345",
 210                },
 211                // Failing strings
 212                {
 213                    "-1234",
 214                    "-124",
 215                });
 216
 217    test_schema("max 9999",
 218                // Schema
 219                R"""({
 220            "type": "integer",
 221            "maximum": 9999
 222        })""",
 223                // Passing strings
 224                {
 225                    "-99999",
 226                    "0",
 227                    "9999",
 228                },
 229                // Failing strings
 230                {
 231                    "10000",
 232                    "99991",
 233                });
 234    test_schema("max -9999",
 235                // Schema
 236                R"""({
 237            "type": "integer",
 238            "maximum": -9999
 239        })""",
 240                // Passing strings
 241                {
 242                    "-10000",
 243                    "-9999",
 244                },
 245                // Failing strings
 246                {
 247                    "-9998",
 248                    "0",
 249                    "9999",
 250                });
 251    test_schema("min 5 max 30",
 252                // Schema
 253                R"""({
 254            "type": "integer",
 255            "minimum": 5,
 256            "maximum": 30
 257        })""",
 258                // Passing strings
 259                {
 260                    "5",
 261                    "10",
 262                    "30",
 263                },
 264                // Failing strings
 265                {
 266                    "05",
 267                    "4",
 268                    "-1",
 269                    "31",
 270                    "123",
 271                    "0123",
 272                });
 273    test_schema("min -1 max 1",
 274                R"""({
 275            "type": "integer",
 276            "minimum": -1,
 277            "maximum": 1
 278        })""",
 279                // Passing strings
 280                {
 281                    "-1",
 282                    "0",
 283                    "1",
 284                },
 285                // Failing strings
 286                {
 287                    "-11",
 288                    "-10",
 289                    "-2",
 290                    "2",
 291                    "10",
 292                    "11",
 293                });
 294    test_schema("min -123 max 42",
 295                R"""({
 296            "type": "integer",
 297            "minimum": -123,
 298            "maximum": 42
 299        })""",
 300                // Passing strings
 301                {
 302                    "-123",
 303                    "-122",
 304                    "-13",
 305                    "-11",
 306                    "-2",
 307                    "-1",
 308                    "0",
 309                    "1",
 310                    "5",
 311                    "10",
 312                    "39",
 313                    "40",
 314                    "42",
 315                },
 316                // Failing strings
 317                {
 318                    "-0123",
 319                    "-124",
 320                    "-1123",
 321                    "-200",
 322                    "43",
 323                    "123",
 324                    "0123",
 325                });
 326    test_schema("exclusive min / max",
 327                // Schema
 328                R"""({
 329            "type": "integer",
 330            "exclusiveMinimum": 0,
 331            "exclusiveMaximum": 10000
 332        })""",
 333                // Passing strings
 334                {
 335                    "1",
 336                    "9999",
 337                },
 338                // Failing strings
 339                {
 340                    "0",
 341                    "01",
 342                    "10000",
 343                    "99999",
 344                });
 345
 346    // Test case for a simple grammar
 347    test_grammar("simple grammar",
 348                 R"""(
 349            start: expr
 350            expr: term ("+" term)*
 351            term: number
 352            number: /[0-9]+/ )""",
 353                 // Passing strings
 354                 {
 355                     "42",
 356                     "1+2+3+4+5",
 357                     "123+456",
 358                 },
 359                 // Failing strings
 360                 {
 361                     "+",
 362                     "/ 3",
 363                     "1+2+3+4+5+",
 364                     "12a45",
 365                 });
 366}
 367
 368static void test_complex_grammar() {
 369    // Test case for a more complex grammar, with both failure strings and success strings
 370    test_grammar("medium complexity grammar",
 371                 // Grammar
 372                 R"""(
 373            start: expression
 374            expression: term ws (("+"|"-") ws term)*
 375            term: factor ws (("*"|"/") ws factor)*
 376            factor: number | variable | "(" expression ")" | function-call
 377            number: /[0-9]+/
 378            variable: /[a-zA-Z_][a-zA-Z0-9_]*/
 379            function-call: variable ws "(" (expression ("," ws expression)*)? ")"
 380            ws: /[ \t\n\r]?/ )""",
 381                 // Passing strings
 382                 { "42",
 383                   "1*2*3*4*5",
 384                   "x",
 385                   "x+10",
 386                   "x1+y2",
 387                   "(a+b)*(c-d)",
 388                   "func()",
 389                   "func(x,y+2)",
 390                   "a*(b+c)-d/e",
 391                   "f(g(x),h(y,z))",
 392                   "x + 10",
 393                   "x1 + y2",
 394                   "(a + b) * (c - d)",
 395                   "func()",
 396                   "func(x, y + 2)",
 397                   "a * (b + c) - d / e",
 398                   "f(g(x), h(y, z))",
 399                   "123+456",
 400                   "123*456*789-123/456+789*123",
 401                   "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
 402                 // Failing strings
 403                 {
 404                     "+",
 405                     "/ 3x",
 406                     "x + + y",
 407                     "a * / b",
 408                     "func(,)",
 409                     "func(x y)",
 410                     "(a + b",
 411                     "x + y)",
 412                     "a + b * (c - d",
 413                     "42 +",
 414                     "x +",
 415                     "x + 10 +",
 416                     "(a + b) * (c - d",
 417                     "func(",
 418                     "func(x, y + 2",
 419                     "a * (b + c) - d /",
 420                     "f(g(x), h(y, z)",
 421                     "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
 422                 });
 423}
 424
 425static void test_special_chars() {
 426    // A collection of tests to exercise special characters such as "."
 427    test_grammar("special characters",
 428                 // Grammar
 429                 R"""(
 430            start: /.../ "abc" /.../
 431            )""",
 432                 // Passing strings
 433                 { "abcabcabc", "aaaabcccc",
 434                   // NOTE: Also ensures that multi-byte characters still count as a single character
 435                   "๐Ÿ”ต๐ŸŸ โœ…abcโŒ๐ŸŸ ๐Ÿ”ต" },
 436                 // Failing strings
 437                 { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "๐Ÿ”ต๐ŸŸ โœ…โŒabcโŒโœ…๐ŸŸ ๐Ÿ”ต", "๐Ÿ”ต๐ŸŸ abc๐ŸŸ ๐Ÿ”ต" });
 438}
 439
 440static void test_quantifiers() {
 441    // A collection of tests to exercise * + and ? quantifiers
 442
 443    test_grammar(
 444        "* quantifier",
 445        // Grammar
 446        R"""(start: "a"*)""",
 447        // Passing strings
 448        { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
 449        // Failing strings
 450        { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
 451    test_grammar(
 452        "+ quantifier",
 453        // Grammar
 454        R"""(start: "a"+)""",
 455        // Passing strings
 456        { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
 457        // Failing strings
 458        { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
 459    test_grammar("? quantifier",
 460                 // Grammar
 461                 R"""(start: "a"?)""",
 462                 // Passing strings
 463                 { "", "a" },
 464                 // Failing strings
 465                 {
 466                     "b",
 467                     "ab",
 468                     "aa",
 469                     "ba",
 470                 });
 471    test_grammar("mixed quantifiers",
 472                 // Grammar
 473                 R"""(
 474            start: cons+ vowel* cons? (vowel cons)*
 475            vowel: /[aeiouy]/
 476            cons: /[bcdfghjklmnpqrstvwxyz]/
 477            )""",
 478                 // Passing strings
 479                 {
 480                     "yes",
 481                     "no",
 482                     "noyes",
 483                     "crwth",
 484                     "four",
 485                     "bryyyy",
 486                 },
 487                 // Failing strings
 488                 {
 489                     "yess",
 490                     "yesno",
 491                     "forty",
 492                     "catyyy",
 493                 });
 494    test_grammar("simple exact repetition",
 495                 // Grammar
 496                 R"""(
 497            start: /[ab]{4}/
 498        )""",
 499                 // Passing strings
 500                 {
 501                     "aaaa",
 502                     "bbbb",
 503                     "abab",
 504                 },
 505                 // Failing strings
 506                 {
 507                     "a",
 508                     "b",
 509                     "aaaaa",
 510                 });
 511    test_grammar("simple min repetition",
 512                 // Grammar
 513                 R"""(
 514            start: /[ab]{4,}/
 515        )""",
 516                 // Passing strings
 517                 {
 518                     "aaaa",
 519                     "aaaaab",
 520                     "bbbb",
 521                     "ababab",
 522                 },
 523                 // Failing strings
 524                 {
 525                     "",
 526                     "aba",
 527                 });
 528    test_grammar("simple max repetition",
 529                 // Grammar
 530                 R"""(
 531            start: /[ab]{0,4}/
 532        )""",
 533                 // Passing strings
 534                 {
 535                     "",
 536                     "a",
 537                     "aa",
 538                     "aaa",
 539                     "aaab",
 540                 },
 541                 // Failing strings
 542                 {
 543                     "aaaaa",
 544                 });
 545    // test_grammar("min / max repetition",
 546    //              // Grammar
 547    //              R"""(
 548    //         start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
 549    //     )""",
 550    //              // Passing strings
 551    //              {
 552    //                  "0xFF 0x12 0xAB",
 553    //                  "0xFF 0x12 0xAB 0x00 0x00",
 554    //              },
 555    //              // Failing strings
 556    //              {
 557    //                  "",
 558    //                  "0xFF",
 559    //                  "0xFF 0x12",
 560    //                  "0xFF 0x12 0xAB 0x00 0x00 0x00",
 561    //              });
 562}
 563
 564static void test_json_schema() {
 565    // Note that this is similar to the regular grammar tests,
 566    //  but we convert each json schema to a grammar before parsing.
 567    // Otherwise, this test structure is the same.
 568
 569    test_schema("empty schema (object)",
 570                // Schema
 571                R"""(
 572            {"type":"object"}
 573        )""",
 574                // Passing strings
 575                {
 576                    R"""({})""",
 577                    R"""({"foo": "bar"})""",
 578                },
 579                // Failing strings
 580                {
 581                    "",
 582                    "[]",
 583                    "null",
 584                    R"""("")""",
 585                    "true",
 586                });
 587
 588    test_schema(
 589        "exotic formats (list)",
 590        // Schema
 591        R"""({
 592            "items": [
 593                { "format": "date" },
 594                { "format": "uuid" },
 595                { "format": "time" },
 596                { "format": "date-time" }
 597            ]
 598        })""",
 599        // Passing strings
 600        {
 601            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
 602            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
 603            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
 604            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
 605            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
 606        },
 607        // Failing strings
 608        {
 609            R"""(["foo", "bar"])""",
 610            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
 611        });
 612
 613    test_schema("string",
 614                // Schema
 615                R"""({
 616            "type": "string"
 617        })""",
 618                // Passing strings
 619                {
 620                    R"""("foo")""",
 621                    R"""("bar")""",
 622                    R"""("")""",
 623                },
 624                // Failing strings
 625                {
 626                    R"""({})""",
 627                    R"""("foo": "bar")""",
 628                });
 629
 630    test_schema("string w/ min length 1",
 631                // Schema
 632                R"""({
 633            "type": "string",
 634            "minLength": 1
 635        })""",
 636                // Passing strings
 637                {
 638                    R"""("foo")""",
 639                    R"""("bar")""",
 640                },
 641                // Failing strings
 642                {
 643                    R"""("")""",
 644                    R"""({})""",
 645                    R"""("foo": "bar")""",
 646                });
 647
 648    test_schema("string w/ min length 3",
 649                // Schema
 650                R"""({
 651                "type": "string",
 652                "minLength": 3
 653        })""",
 654                // Passing strings
 655                {
 656                    R"""("foo")""",
 657                    R"""("bar")""",
 658                    R"""("foobar")""",
 659                },
 660                // Failing strings
 661                {
 662                    R"""("")""",
 663                    R"""("f")""",
 664                    R"""("fo")""",
 665                });
 666
 667    test_schema("string w/ max length",
 668                // Schema
 669                R"""({
 670            "type": "string",
 671            "maxLength": 3
 672        })""",
 673                // Passing strings
 674                {
 675                    R"""("foo")""",
 676                    R"""("bar")""",
 677                    R"""("")""",
 678                    R"""("f")""",
 679                    R"""("fo")""",
 680                },
 681                // Failing strings
 682                {
 683                    R"""("foobar")""",
 684                });
 685
 686    test_schema("string w/ min & max length",
 687                // Schema
 688                R"""({
 689            "type": "string",
 690            "minLength": 1,
 691            "maxLength": 4
 692        })""",
 693                // Passing strings
 694                {
 695                    R"""("foo")""",
 696                    R"""("bar")""",
 697                    R"""("f")""",
 698                    R"""("barf")""",
 699                },
 700                // Failing strings
 701                {
 702                    R"""("")""",
 703                    R"""("barfo")""",
 704                    R"""("foobar")""",
 705                });
 706
 707    test_schema("boolean",
 708                // Schema
 709                R"""({
 710            "type": "boolean"
 711        })""",
 712                // Passing strings
 713                {
 714                    "true",
 715                    "false",
 716                },
 717                // Failing strings
 718                {
 719                    R"""("")""",
 720                    R"""("true")""",
 721                    R"""(True)""",
 722                    R"""(FALSE)""",
 723                });
 724
 725    test_schema("integer",
 726                // Schema
 727                R"""({
 728            "type": "integer"
 729        })""",
 730                // Passing strings
 731                {
 732                    R"""(0)""",
 733                    R"""(12345)""",
 734                    R"""(1234567890123456)""",
 735                },
 736                // Failing strings
 737                {
 738                    R"""()""",
 739                    R"""(01)""",
 740                    R"""(007)""",
 741                    R"""(12345678901234567  )""",
 742                });
 743
 744    test_schema("string const",
 745                // Schema
 746                R"""({
 747            "const": "foo"
 748        })""",
 749                // Passing strings
 750                {
 751                    R"""("foo")""",
 752                },
 753                // Failing strings
 754                {
 755                    R"""(foo)""",
 756                    R"""("bar")""",
 757                });
 758
 759    test_schema("non-string const",
 760                // Schema
 761                R"""({
 762            "const": true
 763        })""",
 764                // Passing strings
 765                {
 766                    R"""(true)""",
 767                },
 768                // Failing strings
 769                {
 770                    R"""()""",
 771                    R"""(foo)""",
 772                    R"""("true")""",
 773                });
 774
 775    test_schema("non-string const",
 776                // Schema
 777                R"""({
 778            "enum": ["red", "amber", "green", null, 42, ["foo"]]
 779        })""",
 780                // Passing strings
 781                {
 782                    R"""("red")""",
 783                    R"""(null)""",
 784                    R"""(42)""",
 785                    R"""(["foo"])""",
 786                },
 787                // Failing strings
 788                {
 789                    R"""()""",
 790                    R"""(420)""",
 791                    R"""(true)""",
 792                    R"""(foo)""",
 793                });
 794
 795    test_schema("simple pattern",
 796                // Schema
 797                R"""({
 798            "pattern": "^[a-zA-Z0-9_-]*$"
 799        })""",
 800                // Passing strings
 801                {
 802                    R"""("")""",
 803                    R"""("He_llo-12")""",
 804                },
 805                // Failing strings
 806                {
 807                    R"""("!")""",
 808                    R"""("Hello World")""",
 809                });
 810
 811    test_schema("pattern with escapes",
 812                // Schema
 813                R"""({
 814            "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
 815        })""",
 816                // Passing strings
 817                {
 818                    R"""("a^$.[]()|{}*+?b")""",
 819                },
 820                // Failing strings
 821                {
 822                    R"""("ab")""",
 823                });
 824
 825    test_schema("",
 826                // Schema
 827                R"""(
 828            {
 829                "type": ["array", "null"],
 830                "items": { "type": "string" }
 831            }
 832        )""",
 833                // Passing strings
 834                {
 835                    "null",
 836                    "[]",
 837                    "[\"123\"]",
 838                    "[\"foo\", \"bar\"]",
 839                },
 840                // Failing strings
 841                {
 842                    "",
 843                    "[123]",
 844                    "\"foo\"",
 845                    "[\"foo\", 42]",
 846                });
 847
 848    test_schema("min+max items",
 849                // Schema
 850                R"""({
 851            "items": {
 852                "type": ["number", "integer"]
 853            },
 854            "minItems": 3,
 855            "maxItems": 5
 856        })""",
 857                // Passing strings
 858                {
 859                    R"""([1, 2, 3])""",
 860                    R"""([1, 2, 3, 4])""",
 861                    R"""([1, 2, 3, 4, 5])""",
 862                    // this is in fact correct; keyword do not apply if the type is wrong
 863                    R"""(1)""",
 864                },
 865                // Failing strings
 866                {
 867                    R"""([1, 2])""",
 868                    R"""([1, 2, 3, 4, 5, 6])""",
 869                });
 870
 871    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
 872    test_schema("object properties",
 873                // Schema
 874                R"""({
 875            "type": "object",
 876            "properties": {
 877                "number": { "type": "number" },
 878                "street_name": { "type": "string" },
 879                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
 880            },
 881            "additionalProperties": false
 882        })""",
 883                // Passing strings
 884                {
 885                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
 886                    // "By default, leaving out properties is valid"
 887                    R"""({ "street_name": "Pennsylvania" })""",
 888                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
 889                    // "By extension, even an empty object is valid"
 890                    R"""({})""",
 891                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
 892                },
 893                // Failing strings
 894                {
 895                    // Change datatype from number to string
 896                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
 897                    // Reorder properties
 898                    R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
 899                    // Reorder properties
 900                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
 901                    // Additional properties set to false
 902                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
 903
 904                });
 905
 906    test_schema("additional properties can't override other properties",
 907                R"""({
 908            "properties": {
 909                "a": {"type": "integer"},
 910                "b": {"type": "integer"}
 911            },
 912            "additionalProperties": true
 913        })""",
 914                // Passing strings
 915                {
 916                    R"""({"a": 42})""",
 917                    R"""({"c": ""})""",
 918                    R"""({"a": 42, "c": ""})""",
 919                    R"""({"a_": ""})""",
 920                },
 921                // Failing strings
 922                {
 923                    R"""()""",
 924                    R"""({"a": ""})""",
 925                    R"""({"a": "", "b": ""})""",
 926                });
 927
 928    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
 929    test_schema("object properties, additionalProperties: true",
 930                // Schema
 931                R"""({
 932            "type": "object",
 933            "properties": {
 934                "number": { "type": "number" },
 935                "street_name": { "type": "string" },
 936                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
 937            },
 938            "additionalProperties": true
 939        })""",
 940                // Passing strings
 941                {
 942                    // "By extension, even an empty object is valid"
 943                    R"""({})""",
 944                    R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
 945                    // "By default, leaving out properties is valid"
 946                    R"""({ "street_name": "Pennsylvania" })""",
 947                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
 948                    // "By default, providing additional properties is valid"
 949                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
 950                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
 951                },
 952                // Failing strings
 953                {
 954                    // Change datatype from number to string
 955                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
 956                    // Reorder properties
 957                    R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
 958                });
 959
 960    // Additional properties: false
 961    test_schema(
 962        "required + optional props each in original order",
 963        // Schema
 964        R"""({
 965            "type": "object",
 966            "properties": {
 967                "number": { "type": "number" },
 968                "street_name": { "type": "string" },
 969                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
 970            },
 971            "additionalProperties": false
 972        })""",
 973        // Passing strings
 974        {
 975            R"""({ "street_name": "Pennsylvania" })""",
 976            R"""({ "number": 1600, "street_type":"Avenue"})""",
 977            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
 978            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
 979            // Spaces are permitted around enum values
 980            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
 981        },
 982        // Failing strings
 983        {
 984            // Reorder properties
 985            R"""({ "street_type": "Avenue", "number": 1600 })""",
 986            // Add "direction"
 987            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
 988        });
 989
 990    test_schema("required + optional props each in original order",
 991                // Schema
 992                R"""({
 993            "properties": {
 994                "b": {"type": "string"},
 995                "a": {"type": "string"},
 996                "d": {"type": "string"},
 997                "c": {"type": "string"}
 998            },
 999            "required": ["a", "b"],
1000            "additionalProperties": false
1001        })""",
1002                // Passing strings
1003                {
1004                    R"""({"b": "foo", "a": "bar"})""",
1005                    R"""({"b":"foo","a":"bar","d":"qux"})""",
1006                    R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
1007                },
1008                // Failing strings
1009                {
1010                    R"""({"a": "foo", "b": "bar"})""",
1011                    R"""({"b": "bar"})""",
1012                    R"""({"a": "foo", "c": "baz"})""",
1013                    R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
1014                });
1015
1016    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
1017    test_schema(
1018        "required props",
1019        // Schema
1020        R"""({
1021            "$schema": "https://json-schema.org/draft/2020-12/schema",
1022            "$id": "https://example.com/product.schema.json",
1023            "title": "Product",
1024            "description": "A product from Acme's catalog",
1025            "type": "object",
1026            "properties": {
1027                "productId": {
1028                "description": "The unique identifier for a product",
1029                "type": "integer"
1030                },
1031                "productName": {
1032                "description": "Name of the product",
1033                "type": "string"
1034                },
1035                "price": {
1036                "description": "The price of the product",
1037                "type": "number",
1038                "exclusiveMinimum": 0
1039                },
1040                "tags": {
1041                "description": "Tags for the product",
1042                "type": "array",
1043                "items": {
1044                    "type": "string"
1045                },
1046                "minItems": 1,
1047                "DISABLED_uniqueItems": true
1048                },
1049                "dimensions": {
1050                "type": "object",
1051                "properties": {
1052                    "length": {
1053                    "type": "number"
1054                    },
1055                    "width": {
1056                    "type": "number"
1057                    },
1058                    "height": {
1059                    "type": "number"
1060                    }
1061                },
1062                "required": [ "length", "width", "height" ]
1063                }
1064            },
1065            "required": [ "productId", "productName", "price" ]
1066        })""",
1067        // Passing strings
1068        {
1069            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
1070            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
1071            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
1072        },
1073        // Failing strings
1074        {
1075            R"""({})""",  // Missing all required properties
1076            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""",  // Out of order properties
1077            // `exclusiveMinimum` is OK for llg
1078            R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
1079            R"""({"productId": 1, "productName": "A green door"})""",  // Missing required property (price)
1080            R"""({"productName": "A green door", "price": 12.50})""",  // Missing required property (productId)
1081            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""",  // tags is empty, but minItems is 1
1082            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""",  // Tags and dimensions are out of order
1083            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
1084            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
1085        });
1086}
1087
1088static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
1089    auto n_vocab = tok_arr.size;
1090
1091    tok_arr.selected = -1;
1092    tok_arr.sorted   = false;
1093    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1094        tok_arr.data[token_id].id    = token_id;
1095        tok_arr.data[token_id].logit = 0.0f;
1096    }
1097
1098    tok_arr.data[selected].logit = 100.0f;
1099}
1100
1101static void test_sampler_chain(void) {
1102    auto sparams            = llama_sampler_chain_default_params();
1103    sparams.no_perf         = false;
1104    llama_sampler * sampler = llama_sampler_chain_init(sparams);
1105
1106    const auto grammar_data = R"(%llguidance {}
1107start: /[A-Z ]*/)";
1108
1109    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
1110    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
1111
1112    auto input  = "ALL YOUR BASE ARE BELONG TO US";
1113    auto tokens = common_tokenize(vocab, input, false, false);
1114
1115    auto n_vocab = llama_vocab_n_tokens(vocab);
1116
1117    std::vector<llama_token_data> cur;
1118    cur.reserve(n_vocab);
1119    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1120        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
1121    }
1122    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
1123
1124    for (const auto token : tokens) {
1125        one_hot(tok_arr, token);
1126
1127        fprintf(stderr, "applying token: %d\n", token);
1128        llama_sampler_apply(sampler, &tok_arr);
1129
1130        auto idx = tok_arr.selected;
1131        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
1132        assert(cur[tok_arr.selected].id == token);
1133        llama_sampler_accept(sampler, token);
1134    }
1135
1136    auto tok_eos = llama_vocab_eot(vocab);
1137    if (tok_eos == LLAMA_TOKEN_NULL) {
1138        tok_eos = llama_vocab_eos(vocab);
1139    }
1140
1141    one_hot(tok_arr, tok_eos);
1142
1143    llama_sampler_apply(sampler, &tok_arr);
1144    assert(cur[tok_arr.selected].id == tok_eos);
1145}
1146
1147int main(int argc, const char ** argv) {
1148    fprintf(stdout, "Running llguidance integration tests...\n");
1149
1150    if (argc != 2) {
1151        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
1152        return 1;
1153    }
1154
1155    const char * vocab_file = argv[1];
1156
1157    fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
1158
1159    llama_model *   model;
1160    llama_context * ctx;
1161
1162    llama_backend_init();
1163
1164    // load the vocab
1165    {
1166        auto mparams = llama_model_default_params();
1167
1168        mparams.vocab_only = true;
1169
1170        model = llama_model_load_from_file(vocab_file, mparams);
1171
1172        if (model == NULL) {
1173            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
1174            return 1;
1175        }
1176
1177        // needed?
1178        auto cparams = llama_context_default_params();
1179
1180        ctx = llama_init_from_model(model, cparams);
1181
1182        if (ctx == NULL) {
1183            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
1184            llama_model_free(model);
1185            return 1;
1186        }
1187    }
1188
1189    vocab = llama_model_get_vocab(model);
1190
1191    test_simple_grammar();
1192    test_complex_grammar();
1193    test_special_chars();
1194    test_quantifiers();
1195    test_json_schema();
1196
1197    test_sampler_chain();
1198
1199    llama_free(ctx);
1200    llama_model_free(model);
1201
1202    fprintf(stdout, "All tests passed.\n");
1203    return 0;
1204}