#pragma once

#include "llama.h"

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <map>
#include <regex>
#include <string>
#include <utility>
#include <vector>

// Forward declaration: the declarations in this header only hold or accept
// `const llama_vocab *`, so the complete type is not required here.
struct llama_vocab;

// Grammar element type.
//
// Tags a llama_grammar_element and thereby decides how its `value` field is
// interpreted (code point, rule ID or token ID).
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b], [^abc])
    LLAMA_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY       = 7,

    // terminal element: token (<[token-id]>)
    LLAMA_GRETYPE_TOKEN          = 8,

    // inverse token (!<[token-id]>)
    LLAMA_GRETYPE_TOKEN_NOT      = 9,
};

47typedef struct llama_grammar_element {
48 enum llama_gretype type;
49 uint32_t value; // Unicode code point, rule ID, or token ID
50} llama_grammar_element;
51
// State of a UTF-8 sequence that has only been partially decoded so far.
struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};

57struct llama_grammar_candidate {
58 size_t index;
59 const uint32_t * code_points;
60 llama_partial_utf8 partial_utf8;
61 llama_token id;
62};
63
64using llama_grammar_rule = std::vector< llama_grammar_element>;
65using llama_grammar_stack = std::vector<const llama_grammar_element *>;
66
67using llama_grammar_rules = std::vector<llama_grammar_rule>;
68using llama_grammar_stacks = std::vector<llama_grammar_stack>;
69using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
70
71// TODO: remove, needed for tests atm
72const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
73 llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
74
// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
//
// grammar: grammar whose stacks are advanced
// chr:     the character (code point) to accept
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);

81std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
82 const llama_grammar_rules & rules,
83 const llama_grammar_stack & stack,
84 const llama_grammar_candidates & candidates);
85
86struct llama_grammar_parser {
87 const llama_vocab * vocab;
88 std::map<std::string, uint32_t> symbol_ids;
89
90 llama_grammar_rules rules;
91
92 llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
93
94 llama_grammar_stack c_rules() const;
95
96 uint32_t get_symbol_id(const char * src, size_t len);
97 uint32_t generate_symbol_id(const std::string & base_name);
98
99 void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
100
101 const char * parse_alternates(
102 const char * src,
103 const std::string & rule_name,
104 uint32_t rule_id,
105 bool is_nested);
106
107 const char * parse_sequence(
108 const char * src,
109 const std::string & rule_name,
110 llama_grammar_rule & rule,
111 bool is_nested);
112
113 const char * parse_rule(const char * src);
114
115 bool parse(const char * src);
116 void print(FILE * file);
117};
118
// A trigger pattern: the pattern source text next to a std::regex object.
// NOTE(review): `regex` is presumably compiled from `pattern` - confirm
struct llama_grammar_trigger_pattern {
    std::string pattern;
    std::regex  regex;

    // searches `input` using this pattern
    // NOTE(review): presumably returns the match position, or std::string::npos
    // when there is no match - confirm against the implementation
    size_t find(const std::string & input) const;
};

126struct llama_grammar {
127 // maintain a list of llama_tokens and their positions in the trigger_buffer
128 using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
129
130 // note: allow null vocab for testing (not great)
131 const llama_vocab * vocab;
132
133 const llama_grammar_rules rules; // TODO: shared ptr
134 llama_grammar_stacks stacks;
135
136 // buffer for partially generated UTF-8 sequence from accepted tokens
137 llama_partial_utf8 partial_utf8;
138
139 // lazy grammars wait for trigger words or tokens before constraining the sampling.
140 // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
141 // (useful e.g. for tool_choice=required)
142 bool lazy = false;
143 bool awaiting_trigger = false; // Initialized to true for lazy grammars only
144 std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
145 std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
146 std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
147 std::vector<llama_grammar_trigger_pattern>
148 trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
149 // string, and the grammar will be given the string from the first match group onwards.
150
151};
152
153//
154// internal API
155//
156
157// note: needed for tests (not great)
158struct llama_grammar * llama_grammar_init_impl(
159 const struct llama_vocab * vocab,
160 const llama_grammar_element ** rules,
161 size_t n_rules,
162 size_t start_rule_index);
163
164struct llama_grammar * llama_grammar_init_impl(
165 const struct llama_vocab * vocab,
166 const char * grammar_str,
167 const char * grammar_root,
168 bool lazy,
169 const char ** trigger_patterns,
170 size_t num_trigger_patterns,
171 const llama_token * trigger_tokens,
172 size_t num_trigger_tokens);
173
// Releases `grammar`.
// NOTE(review): whether a null pointer is accepted is not visible here - confirm
void llama_grammar_free_impl(struct llama_grammar * grammar);

// Returns a pointer to a copy of `grammar`.
// NOTE(review): the copy is presumably owned by the caller and released with
// llama_grammar_free_impl - confirm
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);

178// TODO: move the API below as member functions of llama_grammar
179void llama_grammar_apply_impl(
180 const struct llama_grammar & grammar,
181 llama_token_data_array * cur_p);
182
183void llama_grammar_accept_impl(
184 struct llama_grammar & grammar,
185 llama_token token);
186
// Feeds the text `piece` to `grammar`; the grammar is taken by mutable reference.
// NOTE(review): behaviour on text the grammar does not accept is not visible
// here - confirm against the implementation
void llama_grammar_accept_str(
        struct llama_grammar & grammar,
        const std::string & piece);

191void llama_grammar_accept_token(
192 struct llama_grammar & grammar,
193 llama_token token,
194 const std::string & piece);