1#pragma once
  2
  3#include "llama.h"
  4
  5#include <map>
  6#include <regex>
  7#include <string>
  8#include <vector>
  9
 10struct llama_vocab;
 11
 12// grammar element type
// grammar element type — the kind of a single llama_grammar_element inside a rule.
// a rule is a flat sequence of elements terminated by LLAMA_GRETYPE_END, with
// LLAMA_GRETYPE_ALT separating alternative sequences within the same rule.
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule (value = rule id)
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (value = Unicode code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b] [^abc])
    LLAMA_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z]) — value is the upper bound of the range
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY       = 7,

    // terminal element: token (<[token-id]>) — value is the token id
    LLAMA_GRETYPE_TOKEN          = 8,

    // inverse token (!<[token-id]>) — matches any token except value
    LLAMA_GRETYPE_TOKEN_NOT      = 9,
};
 46
// one element of a grammar rule; the meaning of `value` depends on `type`
// (see the per-enumerator comments on llama_gretype above)
typedef struct llama_grammar_element {
    enum llama_gretype type;
    uint32_t           value; // Unicode code point, rule ID, or token ID
} llama_grammar_element;
 51
// state of an incomplete UTF-8 sequence: tokens can end mid-code-point, so the
// decoded prefix is carried here until the remaining continuation bytes arrive
struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};
 56
// a sampling candidate to be checked against the grammar
struct llama_grammar_candidate {
    size_t               index;        // position of the candidate in the caller's array
    const uint32_t     * code_points;  // decoded code points of the candidate's text (non-owning)
    llama_partial_utf8   partial_utf8; // trailing incomplete UTF-8 state, if any
    llama_token          id;           // the candidate token id
};
 63
// a rule owns its elements; a stack holds non-owning pointers into the rules,
// representing one possible parse position of the pushdown automaton
using llama_grammar_rule  = std::vector<      llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element *>;

using llama_grammar_rules      = std::vector<llama_grammar_rule>;
using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
 70
// TODO: remove, needed for tests atm
// accessors for the grammar's internal state (rules are read-only; stacks are
// mutable so tests can drive the automaton directly)
const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);

// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);

// returns the subset of `candidates` that the given stack REJECTS
// (i.e. those whose text cannot be matched from this stack position)
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates);
 85
// recursive-descent parser that turns GBNF grammar text into llama_grammar_rules
struct llama_grammar_parser {
    // optional vocab used to resolve token terminals (<[token-id]>); may be null
    const llama_vocab * vocab;
    // maps rule names to their ids (indices into `rules`)
    std::map<std::string, uint32_t> symbol_ids;

    // parsed rules, indexed by symbol id
    llama_grammar_rules rules;

    llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}

    // one pointer per rule into `rules`, in the C-API element-pointer form
    // expected by llama_grammar_init_impl
    llama_grammar_stack c_rules() const;

    // id of the symbol named by src[0..len) — presumably registers it if unseen (see impl)
    uint32_t get_symbol_id(const char * src, size_t len);
    // id of a fresh synthetic symbol derived from base_name (for nested sub-rules)
    uint32_t generate_symbol_id(const std::string & base_name);

    // store `rule` under rule_id in `rules`
    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);

    // the parse_* functions each consume a portion of `src` and return a
    // pointer just past the text they parsed
    const char * parse_alternates(
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    const char * parse_sequence(
            const char         * src,
            const std::string  & rule_name,
            llama_grammar_rule & rule,
            bool               is_nested);

    const char * parse_rule(const char * src);

    // parse a complete grammar string; returns false on a parse error
    bool parse(const char * src);
    // dump the parsed rules in human-readable form (for debugging)
    void print(FILE * file);
};
118
// a compiled trigger regex together with its source text
// (the source `pattern` is kept so the trigger can be cloned/serialized;
// std::regex cannot be converted back to its pattern string)
struct llama_grammar_trigger_pattern {
    std::string pattern;
    std::regex  regex;

    // position of the match in `input` — semantics defined in the .cpp (presumably
    // npos / end-of-input sentinel when there is no match; TODO confirm)
    size_t find(const std::string & input) const;
};
125
// runtime state of a grammar-constrained sampler: the parsed rules plus the
// set of live pushdown-automaton stacks advanced as tokens are accepted
struct llama_grammar {
    // maintain a list of llama_tokens and their positions in the trigger_buffer
    using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;

    // note: allow null vocab for testing (not great)
    const llama_vocab * vocab;

    const llama_grammar_rules  rules;  // TODO: shared ptr
          llama_grammar_stacks stacks; // all currently-possible parse positions

    // buffer for partially generated UTF-8 sequence from accepted tokens
    llama_partial_utf8 partial_utf8;

    // lazy grammars wait for trigger words or tokens before constraining the sampling.
    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
    // (useful e.g. for tool_choice=required)
    bool                     lazy             = false;
    bool                     awaiting_trigger = false; // Initialized to true for lazy grammars only
    std::string              trigger_buffer;           // Output buffered by lazy grammar. Will be cleared once trigger is found.
    std::vector<token_pos>   trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
    std::vector<llama_token> trigger_tokens;           // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
    std::vector<llama_grammar_trigger_pattern>
                             trigger_patterns;         // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
                                                       // string, and the grammar will be given the string from the first match group onwards.

};
152
153//
154// internal API
155//
156
// note: needed for tests (not great)
// construct a grammar from pre-built C-style rule arrays; caller keeps
// ownership of `rules` (the elements are copied — see impl). Returned grammar
// must be released with llama_grammar_free_impl.
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);

// construct a grammar by parsing GBNF text, starting from rule `grammar_root`.
// `lazy`, trigger patterns and trigger tokens configure lazy-trigger behavior
// (see the field comments on struct llama_grammar).
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens);

// destroy a grammar created by llama_grammar_init_impl / llama_grammar_clone_impl
void llama_grammar_free_impl(struct llama_grammar * grammar);

// deep-copy a grammar (returned pointer is owned by the caller)
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);

// TODO: move the API below as member functions of llama_grammar
// constrain the candidate token array in-place to tokens the grammar can accept
void llama_grammar_apply_impl(
        const struct llama_grammar & grammar,
            llama_token_data_array * cur_p);

// advance the grammar state after `token` has been sampled
void llama_grammar_accept_impl(
              struct llama_grammar & grammar,
                       llama_token   token);

// advance the grammar state by a raw text piece (UTF-8, may end mid-code-point)
void llama_grammar_accept_str(
              struct llama_grammar & grammar,
                 const std::string & piece);

// advance by a token together with its decoded text `piece`
// (needed so lazy grammars can track token positions in the buffered output)
void llama_grammar_accept_token(
              struct llama_grammar & grammar,
                       llama_token   token,
                 const std::string & piece);