1#include "sampling.h"
  2#include "log.h"
  3
  4#ifdef LLAMA_USE_LLGUIDANCE
  5
  6#    include "llguidance.h"
  7#    include <cmath>
  8
  9struct llama_sampler_llg {
 10    const llama_vocab * vocab;
 11    std::string         grammar_kind;
 12    std::string         grammar_data;
 13    LlgTokenizer *      tokenizer;
 14    LlgMatcher *        grammar;
 15};
 16
 17static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
 18                                          const char * grammar_data) {
 19    LlgConstraintInit cinit;
 20    llg_constraint_init_set_defaults(&cinit, tokenizer);
 21    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
 22    if (log_level && *log_level) {
 23        cinit.log_stderr_level = atoi(log_level);
 24    }
 25    auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
 26    if (llg_matcher_get_error(c)) {
 27        LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
 28        llg_free_matcher(c);
 29        return nullptr;
 30    }
 31
 32    return c;
 33}
 34
 35static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
 36    return "llguidance";
 37}
 38
 39static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
 40    auto * ctx = (llama_sampler_llg *) smpl->ctx;
 41    if (ctx->grammar) {
 42        llg_matcher_consume_token(ctx->grammar, token);
 43    }
 44}
 45
 46static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
 47    auto * ctx = (llama_sampler_llg *) smpl->ctx;
 48    if (ctx->grammar) {
 49        const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
 50        if (mask == nullptr) {
 51            if (llg_matcher_compute_mask(ctx->grammar) == 0) {
 52                mask = llg_matcher_get_mask(ctx->grammar);
 53            } else {
 54                LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
 55                llg_free_matcher(ctx->grammar);
 56                ctx->grammar = nullptr;
 57                return;
 58            }
 59        }
 60
 61        for (size_t i = 0; i < cur_p->size; ++i) {
 62            auto token = cur_p->data[i].id;
 63            if ((mask[token / 32] & (1 << (token % 32))) == 0) {
 64                cur_p->data[i].logit = -INFINITY;
 65            }
 66        }
 67    }
 68}
 69
 70static void llama_sampler_llg_reset(llama_sampler * smpl) {
 71    auto * ctx = (llama_sampler_llg *) smpl->ctx;
 72    if (ctx->grammar) {
 73        llg_matcher_reset(ctx->grammar);
 74    }
 75}
 76
 77static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
 78    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
 79
 80    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
 81
 82    // copy the state
 83    {
 84        auto * result_ctx = (llama_sampler_llg *) result->ctx;
 85
 86        if (ctx->grammar) {
 87            result_ctx->grammar_kind = ctx->grammar_kind;
 88            result_ctx->grammar_data = ctx->grammar_data;
 89            result_ctx->grammar      = llg_clone_matcher(ctx->grammar);
 90            result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
 91        }
 92    }
 93
 94    return result;
 95}
 96
 97static void llama_sampler_llg_free(llama_sampler * smpl) {
 98    const auto * ctx = (llama_sampler_llg *) smpl->ctx;
 99
100    if (ctx->grammar) {
101        llg_free_matcher(ctx->grammar);
102        llg_free_tokenizer(ctx->tokenizer);
103    }
104
105    delete ctx;
106}
107
108static llama_sampler_i llama_sampler_llg_i = {
109    /* .name              = */ llama_sampler_llg_name,
110    /* .accept            = */ llama_sampler_llg_accept_impl,
111    /* .apply             = */ llama_sampler_llg_apply,
112    /* .reset             = */ llama_sampler_llg_reset,
113    /* .clone             = */ llama_sampler_llg_clone,
114    /* .free              = */ llama_sampler_llg_free,
115    /* .backend_init      = */ NULL,
116    /* .backend_accept    = */ NULL,
117    /* .backend_apply     = */ NULL,
118    /* .backend_set_input = */ NULL,
119};
120
121static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
122                                            uint32_t * output_tokens, size_t output_tokens_len) {
123    const llama_vocab * vocab = (const llama_vocab *) user_data;
124    int                 r     = 0;
125    try {
126        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
127                           true);
128    } catch (const std::exception & e) {
129        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
130    }
131    if (r < 0) {
132        return -r;
133    }
134    return r;
135}
136
137static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
138    // TODO store the tokenizer in the vocab somehow
139    static const llama_vocab * vocab_cache;
140    static LlgTokenizer *      tokenizer_cache;
141
142    if (vocab_cache == vocab) {
143        return llg_clone_tokenizer(tokenizer_cache);
144    }
145
146    auto tok_eos = llama_vocab_eot(vocab);
147    if (tok_eos == LLAMA_TOKEN_NULL) {
148        tok_eos = llama_vocab_eos(vocab);
149    }
150
151    size_t vocab_size = llama_vocab_n_tokens(vocab);
152
153    auto token_lens       = new uint32_t[vocab_size];
154    // we typically have ~7 bytes per token; let's go on the safe side here
155    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
156    auto token_bytes      = new uint8_t[token_bytes_size];
157
158    size_t offset = 0;
159    for (size_t i = 0; i < vocab_size; i++) {
160        size_t max_token = 1024;
161        if (token_bytes_size - offset < max_token) {
162            GGML_ABORT("token_bytes buffer too small\n");
163        }
164
165        llama_token token = i;
166        auto        dp    = (char *) token_bytes + offset;
167        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
168        if (size < 0) {
169            GGML_ABORT("llama_detokenize failed\n");
170        }
171        if (size == 0) {
172            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
173            if (size < 0) {
174                GGML_ABORT("llama_detokenize failed\n");
175            }
176            if (size != 0) {
177                *dp = '\xff';  // special token prefix marker
178                size += 1;
179            }
180        }
181
182        token_lens[i] = size;
183        offset += size;
184    }
185
186    LlgTokenizerInit tinit = {
187        /* .vocab_size                         = */ (uint32_t) vocab_size,
188        /* .tok_eos                            = */ (uint32_t) tok_eos,
189        /* .token_lens                         = */ token_lens,
190        /* .token_bytes                        = */ token_bytes,
191        /* .tokenizer_json                     = */ nullptr,
192        /* .tokenize_assumes_string            = */ true,
193        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
194        /* .use_approximate_greedy_tokenize_fn = */ false,
195        /* .tokenize_user_data                 = */ vocab,
196        /* .slices                             = */ nullptr,
197    };
198
199    char           error_buffer[1024];
200    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
201
202    delete[] token_bytes;
203    delete[] token_lens;
204
205    if (tokenizer == nullptr) {
206        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
207        return tokenizer;
208    }
209
210    if (tokenizer_cache) {
211        llg_free_tokenizer(tokenizer_cache);
212    }
213    vocab_cache     = vocab;
214    tokenizer_cache = tokenizer;
215
216    return llg_clone_tokenizer(tokenizer_cache);
217}
218
219llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
220                                       const char * grammar_data) {
221    auto * ctx = new llama_sampler_llg;
222
223    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
224        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
225        *ctx           = {
226            /* .vocab        = */ vocab,
227            /* .grammar_kind = */ grammar_kind,
228            /* .grammar_data = */ grammar_data,
229            /* .tokenizer    = */ tokenizer,
230            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
231        };
232        if (ctx->grammar) {
233            GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
234                        llg_matcher_get_mask_byte_size(ctx->grammar));
235        }
236    } else {
237        *ctx = {
238            /* .vocab        = */ vocab,
239            /* .grammar_kind = */ {},
240            /* .grammar_data = */ {},
241            /* .tokenizer    = */ nullptr,
242            /* .grammar      = */ nullptr,
243        };
244    }
245
246    return llama_sampler_init(
247        /* .iface = */ &llama_sampler_llg_i,
248        /* .ctx   = */ ctx);
249}
250
251#else
252
253llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
254    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
255    return nullptr;
256}
257
258#endif  // LLAMA_USE_LLGUIDANCE