1#include "sampling.h"
2#include "log.h"
3
4#ifdef LLAMA_USE_LLGUIDANCE
5
6# include "llguidance.h"
7# include <cmath>
8
9struct llama_sampler_llg {
10 const llama_vocab * vocab;
11 std::string grammar_kind;
12 std::string grammar_data;
13 LlgTokenizer * tokenizer;
14 LlgMatcher * grammar;
15};
16
17static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
18 const char * grammar_data) {
19 LlgConstraintInit cinit;
20 llg_constraint_init_set_defaults(&cinit, tokenizer);
21 const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
22 if (log_level && *log_level) {
23 cinit.log_stderr_level = atoi(log_level);
24 }
25 auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
26 if (llg_matcher_get_error(c)) {
27 LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
28 llg_free_matcher(c);
29 return nullptr;
30 }
31
32 return c;
33}
34
35static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
36 return "llguidance";
37}
38
39static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
40 auto * ctx = (llama_sampler_llg *) smpl->ctx;
41 if (ctx->grammar) {
42 llg_matcher_consume_token(ctx->grammar, token);
43 }
44}
45
46static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
47 auto * ctx = (llama_sampler_llg *) smpl->ctx;
48 if (ctx->grammar) {
49 const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
50 if (mask == nullptr) {
51 if (llg_matcher_compute_mask(ctx->grammar) == 0) {
52 mask = llg_matcher_get_mask(ctx->grammar);
53 } else {
54 LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
55 llg_free_matcher(ctx->grammar);
56 ctx->grammar = nullptr;
57 return;
58 }
59 }
60
61 for (size_t i = 0; i < cur_p->size; ++i) {
62 auto token = cur_p->data[i].id;
63 if ((mask[token / 32] & (1 << (token % 32))) == 0) {
64 cur_p->data[i].logit = -INFINITY;
65 }
66 }
67 }
68}
69
70static void llama_sampler_llg_reset(llama_sampler * smpl) {
71 auto * ctx = (llama_sampler_llg *) smpl->ctx;
72 if (ctx->grammar) {
73 llg_matcher_reset(ctx->grammar);
74 }
75}
76
77static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
78 const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
79
80 auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
81
82 // copy the state
83 {
84 auto * result_ctx = (llama_sampler_llg *) result->ctx;
85
86 if (ctx->grammar) {
87 result_ctx->grammar_kind = ctx->grammar_kind;
88 result_ctx->grammar_data = ctx->grammar_data;
89 result_ctx->grammar = llg_clone_matcher(ctx->grammar);
90 result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
91 }
92 }
93
94 return result;
95}
96
97static void llama_sampler_llg_free(llama_sampler * smpl) {
98 const auto * ctx = (llama_sampler_llg *) smpl->ctx;
99
100 if (ctx->grammar) {
101 llg_free_matcher(ctx->grammar);
102 llg_free_tokenizer(ctx->tokenizer);
103 }
104
105 delete ctx;
106}
107
108static llama_sampler_i llama_sampler_llg_i = {
109 /* .name = */ llama_sampler_llg_name,
110 /* .accept = */ llama_sampler_llg_accept_impl,
111 /* .apply = */ llama_sampler_llg_apply,
112 /* .reset = */ llama_sampler_llg_reset,
113 /* .clone = */ llama_sampler_llg_clone,
114 /* .free = */ llama_sampler_llg_free,
115 /* .backend_init = */ NULL,
116 /* .backend_accept = */ NULL,
117 /* .backend_apply = */ NULL,
118 /* .backend_set_input = */ NULL,
119};
120
121static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
122 uint32_t * output_tokens, size_t output_tokens_len) {
123 const llama_vocab * vocab = (const llama_vocab *) user_data;
124 int r = 0;
125 try {
126 r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
127 true);
128 } catch (const std::exception & e) {
129 GGML_ABORT("llama_tokenize failed: %s\n", e.what());
130 }
131 if (r < 0) {
132 return -r;
133 }
134 return r;
135}
136
137static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
138 // TODO store the tokenizer in the vocab somehow
139 static const llama_vocab * vocab_cache;
140 static LlgTokenizer * tokenizer_cache;
141
142 if (vocab_cache == vocab) {
143 return llg_clone_tokenizer(tokenizer_cache);
144 }
145
146 auto tok_eos = llama_vocab_eot(vocab);
147 if (tok_eos == LLAMA_TOKEN_NULL) {
148 tok_eos = llama_vocab_eos(vocab);
149 }
150
151 size_t vocab_size = llama_vocab_n_tokens(vocab);
152
153 auto token_lens = new uint32_t[vocab_size];
154 // we typically have ~7 bytes per token; let's go on the safe side here
155 auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
156 auto token_bytes = new uint8_t[token_bytes_size];
157
158 size_t offset = 0;
159 for (size_t i = 0; i < vocab_size; i++) {
160 size_t max_token = 1024;
161 if (token_bytes_size - offset < max_token) {
162 GGML_ABORT("token_bytes buffer too small\n");
163 }
164
165 llama_token token = i;
166 auto dp = (char *) token_bytes + offset;
167 auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
168 if (size < 0) {
169 GGML_ABORT("llama_detokenize failed\n");
170 }
171 if (size == 0) {
172 size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
173 if (size < 0) {
174 GGML_ABORT("llama_detokenize failed\n");
175 }
176 if (size != 0) {
177 *dp = '\xff'; // special token prefix marker
178 size += 1;
179 }
180 }
181
182 token_lens[i] = size;
183 offset += size;
184 }
185
186 LlgTokenizerInit tinit = {
187 /* .vocab_size = */ (uint32_t) vocab_size,
188 /* .tok_eos = */ (uint32_t) tok_eos,
189 /* .token_lens = */ token_lens,
190 /* .token_bytes = */ token_bytes,
191 /* .tokenizer_json = */ nullptr,
192 /* .tokenize_assumes_string = */ true,
193 /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
194 /* .use_approximate_greedy_tokenize_fn = */ false,
195 /* .tokenize_user_data = */ vocab,
196 /* .slices = */ nullptr,
197 };
198
199 char error_buffer[1024];
200 LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
201
202 delete[] token_bytes;
203 delete[] token_lens;
204
205 if (tokenizer == nullptr) {
206 LOG_ERR("llg tokenizer error: %s\n", error_buffer);
207 return tokenizer;
208 }
209
210 if (tokenizer_cache) {
211 llg_free_tokenizer(tokenizer_cache);
212 }
213 vocab_cache = vocab;
214 tokenizer_cache = tokenizer;
215
216 return llg_clone_tokenizer(tokenizer_cache);
217}
218
219llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
220 const char * grammar_data) {
221 auto * ctx = new llama_sampler_llg;
222
223 if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
224 auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
225 *ctx = {
226 /* .vocab = */ vocab,
227 /* .grammar_kind = */ grammar_kind,
228 /* .grammar_data = */ grammar_data,
229 /* .tokenizer = */ tokenizer,
230 /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
231 };
232 if (ctx->grammar) {
233 GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
234 llg_matcher_get_mask_byte_size(ctx->grammar));
235 }
236 } else {
237 *ctx = {
238 /* .vocab = */ vocab,
239 /* .grammar_kind = */ {},
240 /* .grammar_data = */ {},
241 /* .tokenizer = */ nullptr,
242 /* .grammar = */ nullptr,
243 };
244 }
245
246 return llama_sampler_init(
247 /* .iface = */ &llama_sampler_llg_i,
248 /* .ctx = */ ctx);
249}
250
251#else
252
253llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
254 LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
255 return nullptr;
256}
257
258#endif // LLAMA_USE_LLGUIDANCE