1#include "arg.h"
  2#include "common.h"
  3#include "sampling.h"
  4#include "log.h"
  5#include "llama.h"
  6
  7#include <cstdio>
  8#include <string>
  9#include <vector>
 10#include <algorithm>
 11
 12struct ngram_data {
 13    bool active = false;
 14
 15    llama_seq_id seq_id = -1;
 16
 17    std::vector<int> i_batch;
 18
 19    std::vector<llama_token> tokens;
 20};
 21
 22// n-gram container
 23struct ngram_container {
 24    ngram_container(int n_vocab, int N, int G) {
 25        cnt.resize(n_vocab);
 26        head.resize(n_vocab);
 27        tokens.resize(n_vocab * G * (N - 1));
 28    }
 29
 30    int n_total = 0;
 31
 32    std::vector<int> cnt;
 33    std::vector<int> head;
 34
 35    // [n_vocab][G][N - 1]
 36    // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
 37    std::vector<llama_token> tokens;
 38};
 39
 40int main(int argc, char ** argv) {
 41    common_params params;
 42
 43    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
 44        return 1;
 45    }
 46
 47    common_init();
 48
 49    const int W = 15; // lookahead window
 50    const int N = 5;  // n-gram size
 51    const int G = 15; // max verification n-grams
 52
 53    // lookahead requires W + G + 1 sequences for parallel Jacobi decoding
 54    params.n_parallel = W + G + 1;
 55
 56    // unified KV cache is required for coupled sequences in batch splitting
 57    params.kv_unified = true;
 58
 59    // init llama.cpp
 60    llama_backend_init();
 61    llama_numa_init(params.numa);
 62
 63    // load the target model
 64    auto llama_init = common_init_from_params(params);
 65
 66    auto * model = llama_init->model();
 67    auto * ctx   = llama_init->context();
 68
 69    auto * mem = llama_get_memory(ctx);
 70
 71    const llama_vocab * vocab = llama_model_get_vocab(model);
 72
 73    // Tokenize the prompt
 74    std::vector<llama_token> inp;
 75    std::vector<llama_token> all;
 76
 77    inp = common_tokenize(ctx, params.prompt, true, true);
 78    all = inp;
 79
 80    const int max_context_size     = llama_n_ctx(ctx);
 81    const int max_tokens_list_size = max_context_size - 4;
 82
 83    if ((int) inp.size() > max_tokens_list_size) {
 84        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
 85        return 1;
 86    }
 87
 88    LOG("\n\n");
 89
 90    for (auto id : inp) {
 91        LOG("%s", common_token_to_piece(ctx, id).c_str());
 92    }
 93
 94    fflush(stderr);
 95
 96    const int n_input = inp.size();
 97
 98    const auto t_enc_start = ggml_time_us();
 99
100    // eval the prompt
101    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
102    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));
103
104    for (int s = 1; s < W + G + 1; ++s) {
105        llama_memory_seq_cp(mem, 0, s, -1, -1);
106    }
107
108    const auto t_enc_end = ggml_time_us();
109
110    int n_predict = 0;
111    int n_accept  = 0;
112
113    int n_past = inp.size();
114
115    llama_token id = 0;
116
117    // used to determine end of generation
118    bool has_eos = false;
119
120    // for each decoded batch, we have at most W + G + 1 distinct sequences:
121    // seq_id == 0           : the current input token
122    // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations
123    // seq_id [W + 1, W + G] : verification n-grams
124    llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
125
126    // target model sampling context
127    struct common_sampler * smpl = common_sampler_init(model, params.sampling);
128
129    // verification n-grams
130    std::vector<ngram_data> ngrams_cur(G);
131
132    // tokens for the past N - 1 Jacobi iterations
133    std::vector<llama_token> tokens_j_prev(W);
134    std::vector<std::vector<llama_token>> tokens_j(N - 1);
135    for (int j = 0; j < N - 1; j++) {
136        tokens_j[j].resize(W);
137
138        for (int i = 0; i < W; i++) {
139            // there are different ways to init these tokens
140            if (0) {
141                // initialize randomly from the prompt tokens
142                tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
143            } else {
144                // initialize with a sequence of increasing numbers
145                tokens_j[j][i] = 100 + i;
146            }
147        }
148    }
149
150    std::vector<llama_seq_id> seq_id_look;
151
152    // the input token belongs both to all sequences
153    std::vector<llama_seq_id> seq_id_all(W + G + 1);
154    for (int i = 0; i < W + G + 1; i++) {
155        seq_id_all[i] = i;
156    }
157
158    // here we keep adding new n-grams as we go
159    ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
160
161    const auto t_dec_start = ggml_time_us();
162
163    // sample first token
164    {
165        id = common_sampler_sample(smpl, ctx, 0);
166
167        common_sampler_accept(smpl, id, true);
168
169        {
170            const std::string token_str = common_token_to_piece(ctx, id);
171
172            LOG("%s", token_str.c_str());
173            fflush(stdout);
174        }
175    }
176
177    while (true) {
178        // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
179        //
180        // Example for W = 5, N = 4, G = 2:
181        // (I = input, L = lookahead, V = verification)
182        //
183        // Batch:  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
184        // T:        -2 -2 -2 -2 -1 -1 -1 -1 -1  0  0  0  0  0  0
185        // Info:   I  L  L  L  L  L  L  L  L  L  L  L  L  L  L  V  V  V  V  V  V
186        // Pos:    0  1  2  3  4  1  2  3  4  5  2  3  4  5  6  1  2  3  1  2  3   (+ n_past)
187        // Logits: 1  0  0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  1  1  1  1
188        // ---------------------------------------------------------------------
189        // Seq:    0
190        //         1              1              1
191        //         2  2              2              2
192        //         3  3  3              3              3
193        //         4  4  4  4              4              4
194        //         5  5  5  5  5              5              5
195        //         6                                            6  6  6
196        //         7                                                     7  7  7
197        // ---------------------------------------------------------------------
198        //                                       |  |  |  |  |  |  |  |  |  |  |
199        //                                       V  V  V  V  V  |  |  |  |  |  |
200        //                                         j_tokens     |  |  |  |  |  |
201        //                                                      V  V  V  V  V  V
202        //                                                             id
203        {
204            common_batch_clear(batch);
205
206            // current token - first token of the first level
207            common_batch_add(batch, id, n_past, seq_id_all, true);
208
209            // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
210            {
211                const int g_cur = ngrams_observed.cnt[id];
212
213                ngrams_cur.resize(g_cur);
214                for (int g = 0; g < g_cur; g++) {
215                    ngrams_cur[g].active = true;
216                    ngrams_cur[g].tokens.resize(N);
217                    ngrams_cur[g].i_batch.resize(N);
218                    ngrams_cur[g].seq_id = W + 1 + g;
219                    ngrams_cur[g].i_batch[0] = 0;
220                    ngrams_cur[g].tokens [0] = id;
221                }
222
223                for (int j = 0; j < N - 1; j++) {
224                    for (int g = 0; g < g_cur; g++) {
225                        const int idx = id*(N - 1)*G + g*(N - 1);
226
227                        const llama_token t = ngrams_observed.tokens[idx + j];
228
229                        ngrams_cur[g].tokens [j + 1] = t;
230                        ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
231
232                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
233                    }
234                }
235            }
236
237            // fill the remaining W - 1 tokens for the first level
238            for (int i = 1; i < W; i++) {
239                seq_id_look.resize(W - i);
240                for (int j = 0; j < W - i; j++) {
241                    seq_id_look[j] = i + j + 1;
242                }
243
244                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
245            }
246
247            // fill the rest of the levels
248            for (int j = 1; j < N - 1; j++) {
249                for (int i = 0; i < W; i++) {
250                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
251                }
252            }
253        }
254
255        if (llama_decode(ctx, batch) != 0) {
256            LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
257            return 1;
258        }
259
260        int seq_id_best = 0;
261
262        for (int v = 0; v < N; ++v) {
263            int i_batch = 0;
264
265            // if no active ngrams are left, it means the sampled token does not pass the verification
266            if (v > 0) {
267                for (int g = 0; g < (int) ngrams_cur.size(); g++) {
268                    if (ngrams_cur[g].active) {
269                        i_batch = ngrams_cur[g].i_batch[v];
270                        seq_id_best = ngrams_cur[g].seq_id;
271
272                        ++n_accept;
273                        break;
274                    }
275                }
276
277                // no more matches -> create a new batch
278                if (i_batch == 0) {
279                    break;
280                }
281            }
282
283            // sample the next token
284            id = common_sampler_sample(smpl, ctx, i_batch);
285
286            common_sampler_accept(smpl, id, true);
287
288            // print
289            {
290                const std::string token_str = common_token_to_piece(ctx, id);
291
292                if (v == 0) {
293                    LOG("%s", token_str.c_str());
294                } else {
295                    // print light cyan
296                    LOG("\033[0;96m%s\033[0m", token_str.c_str());
297                }
298                fflush(stdout);
299
300                if (llama_vocab_is_eog(vocab, id)) {
301                    has_eos = true;
302                }
303
304                all.push_back(id);
305            }
306
307            ++n_predict;
308            ++n_past;
309
310            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
311                break;
312            }
313
314            // verify across active n-grams
315            for (int g = 0; g < (int) ngrams_cur.size(); g++) {
316                if (ngrams_cur[g].active) {
317                    if (v == N - 1) {
318                        ngrams_cur[g].active = false;
319                    } else {
320                        if (id != ngrams_cur[g].tokens[v + 1]) {
321                            ngrams_cur[g].active = false;
322                        }
323                    }
324                }
325            }
326
327            // print known n-grams starting with token id (debug)
328            if (0 && v == 0) {
329                if (ngrams_observed.cnt[id] > 0) {
330                    LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str());
331                }
332
333                for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
334                    LOG("   - ngram %2d: ", i);
335
336                    const int idx = id*(N - 1)*G + i*(N - 1);
337
338                    for (int j = 0; j < N - 1; j++) {
339                        const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
340
341                        LOG("%s", token_str.c_str());
342                    }
343
344                    LOG("\n");
345                }
346            }
347
348            // update lookahead tokens
349            {
350                for (int i = 0; i < W; i++) {
351                    tokens_j_prev[i] = tokens_j[0][i];
352                }
353
354                for (int j = 0; j < N - 2; j++) {
355                    tokens_j[j] = tokens_j[j + 1];
356                }
357
358                if (v == 0) {
359                    // sample from the last level
360                    for (int i = 0; i < W; i++) {
361                        tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
362                    }
363                } else {
364                    for (int i = 0; i < W; i++) {
365                        // there are different ways to init these tokens
366                        if (0) {
367                            // random init
368                            tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
369                        } else {
370                            // init from the previous level
371                            tokens_j[N - 2][i] = tokens_j[0][i];
372                        }
373                    }
374                }
375            }
376
377            // update observed ngrams
378            if (v == 0) {
379                // the first token of the n-gram is determined by the index in the container so it is not stored
380                std::vector<llama_token> ngram(N - 1);
381
382                // n-gram generation
383                // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
384                for (int f = 0; f < W; ++f) {
385                    const int ft = tokens_j_prev[f]; // first token of the n-gram
386
387                    for (int j = 0; j < N - 1; ++j) {
388                        ngram[j] = tokens_j[j][f];
389                    }
390
391                    // filter-out repeating n-grams
392                    {
393                        bool is_unique = true;
394
395                        for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
396                            const int idx = ft*(N - 1)*G + k*(N - 1);
397
398                            bool is_match = true;
399                            for (int j = 0; j < N - 1; ++j) {
400                                if (ngrams_observed.tokens[idx + j] != ngram[j]) {
401                                    is_match = false;
402                                    break;
403                                }
404                            }
405
406                            if (is_match) {
407                                is_unique = false;
408                                break;
409                            }
410                        }
411
412                        if (!is_unique) {
413                            continue;
414                        }
415                    }
416
417                    const int head = ngrams_observed.head[ft];
418                    const int idx  = ft*(N - 1)*G + head*(N - 1);
419
420                    for (int i = 0; i < N - 1; i++) {
421                        ngrams_observed.tokens[idx + i] = ngram[i];
422                    }
423
424                    ngrams_observed.cnt[ft]  = std::min(G, ngrams_observed.cnt[ft] + 1);
425                    ngrams_observed.head[ft] = (head + 1) % G;
426
427                    ngrams_observed.n_total++;
428                }
429            }
430        }
431
432        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
433            break;
434        }
435
436        // KV cache management
437        // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
438        llama_memory_seq_rm(mem, -1, n_past, -1);
439
440        if (seq_id_best != 0) {
441            // if a verification token matched, we keep the best sequence and remove the rest
442            // this leads to some KV cache fragmentation
443            llama_memory_seq_keep(mem, seq_id_best);
444            llama_memory_seq_cp  (mem, seq_id_best, 0, -1, -1);
445            llama_memory_seq_rm  (mem, seq_id_best,    -1, -1);
446
447            for (int s = 1; s < W + G + 1; ++s) {
448                llama_memory_seq_cp(mem, 0, s, -1, -1);
449            }
450        }
451    }
452
453    auto t_dec_end = ggml_time_us();
454
455    LOG("\n\n");
456
457    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
458    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
459
460    LOG_INF("\n");
461    LOG_INF("W = %2d\n", W);
462    LOG_INF("N = %2d\n", N);
463    LOG_INF("G = %2d\n", G);
464    LOG_INF("\n");
465    LOG_INF("n_predict = %d\n", n_predict);
466    LOG_INF("n_accept  = %d\n", n_accept);
467
468    LOG_INF("\n");
469    common_perf_print(ctx, smpl);
470
471    common_sampler_free(smpl);
472
473    llama_batch_free(batch);
474
475    llama_backend_free();
476
477    LOG("\n\n");
478
479    return 0;
480}