1#include "speculative.h"
2
3#include "common.h"
4#include "ggml.h"
5#include "llama.h"
6#include "log.h"
7#include "ngram-cache.h"
8#include "ngram-map.h"
9#include "ngram-mod.h"
10#include "sampling.h"
11
12#include <algorithm>
13#include <cstring>
14#include <iomanip>
15#include <map>
16
// Maximum allowed difference between the target and draft vocab sizes for the two vocabs to be
// treated as directly compatible (see common_speculative_are_compatible).
static constexpr int SPEC_VOCAB_MAX_SIZE_DIFFERENCE  = 128;

// First token id whose text is compared when checking vocab compatibility; ids below it are skipped
// (presumably reserved/special tokens - TODO confirm).
static constexpr int SPEC_VOCAB_CHECK_START_TOKEN_ID = 5;
19
20const std::vector<enum common_speculative_type> common_speculative_types = {
21 COMMON_SPECULATIVE_TYPE_NONE,
22 COMMON_SPECULATIVE_TYPE_DRAFT,
23 COMMON_SPECULATIVE_TYPE_EAGLE3,
24 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
25 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
26 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
27 COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
28 COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
29};
30
31const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
32 {"none", COMMON_SPECULATIVE_TYPE_NONE},
33 {"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
34 {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
35 {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
36 {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
37 {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
38 {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
39 {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
40};
41
42struct common_speculative_config {
43 common_speculative_type type;
44 common_params_speculative params;
45
46 common_speculative_config(common_speculative_type t,
47 const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
48};
49
50static bool common_speculative_are_compatible(
51 const llama_model * model_tgt,
52 const llama_model * model_dft) {
53 const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
54 const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
55
56 const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
57 LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
58
59 const bool vocab_type_dft = llama_vocab_type(vocab_dft);
60 LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
61
62 if (vocab_type_tgt != vocab_type_dft) {
63 LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
64 LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
65 return false;
66 }
67
68 if (
69 llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
70 llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
71 llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
72 llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
73 ) {
74 LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
75 return false;
76 }
77
78 {
79 const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
80 const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
81 const int vocab_diff = n_vocab_tgt > n_vocab_dft
82 ? n_vocab_tgt - n_vocab_dft
83 : n_vocab_dft - n_vocab_tgt;
84
85 if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
86 LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
87 LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
88 n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
89 return false;
90 }
91
92 for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
93 const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
94 const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
95
96 if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
97 LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
98 LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
99 common_token_to_piece(vocab_tgt, i).c_str(),
100 common_token_to_piece(vocab_dft, i).c_str());
101 return false;
102 }
103 }
104 }
105
106 return true;
107}
108
109// state of an implementation of speculative decoding
110//
111// each implementation has a unique type and a state that is implementation-specific
112// in a subclass of common_speculative_state
113struct common_speculative_state {
114 const enum common_speculative_type type;
115
116 size_t n_call_begin = 0; // number of times this implementation was called for refresh.
117 size_t n_call_draft = 0; // number of times this implementation was called for generation.
118 size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
119
120 size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
121 size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
122 size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
123 size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
124
125 // TODO: track performance of most recent calls
126 const bool gen_perf = true; // whether to generate performance stats.
127
128 int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds.
129 int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds.
130 int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
131
132 common_speculative_state(enum common_speculative_type type) : type(type) {}
133
134 virtual ~common_speculative_state() = default;
135
136 virtual void begin(const llama_tokens & prompt) = 0;
137
138 virtual void draft(
139 const common_params_speculative & params,
140 const llama_tokens & prompt_tgt,
141 llama_token id_last,
142 llama_tokens & result) = 0;
143
144 virtual void accept(uint16_t n_accepted) = 0;
145};
146
147struct common_speculative_state_draft : public common_speculative_state {
148 llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
149 llama_context * ctx_dft;
150
151 common_sampler * smpl;
152
153 llama_batch batch;
154 llama_tokens prompt_dft;
155
156 bool vocab_cmpt = true; // whether retokenization is needed
157 std::unordered_map<std::string, std::string> vocab_map;
158
159 common_speculative_state_draft(
160 enum common_speculative_type type,
161 llama_context * ctx_tgt,
162 llama_context * ctx_dft,
163 const std::vector<std::pair<std::string, std::string>> & replacements)
164 : common_speculative_state(type)
165 , ctx_tgt(ctx_tgt)
166 , ctx_dft(ctx_dft)
167 {
168 batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
169 smpl = nullptr;
170
171 // TODO: optimize or pass from outside?
172 // {
173 // common_params_sampling params;
174 // params.no_perf = false;
175 //
176 // params.top_k = 40;
177 // params.top_p = 0.9;
178 //
179 // params.samplers = {
180 // COMMON_SAMPLER_TYPE_TOP_K,
181 // COMMON_SAMPLER_TYPE_TOP_P,
182 // COMMON_SAMPLER_TYPE_INFILL,
183 // };
184 //
185 // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
186 // }
187 {
188 common_params_sampling params;
189 params.no_perf = false;
190 params.top_k = 10;
191 params.samplers = {
192 COMMON_SAMPLER_TYPE_TOP_K,
193 };
194
195 smpl = common_sampler_init(llama_get_model(ctx_dft), params);
196 }
197
198 vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
199 LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
200
201 if (!vocab_cmpt) {
202 LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
203
204 for (const auto & pair : replacements) {
205 vocab_map[pair.first] = pair.second;
206 }
207 }
208 }
209
210 ~common_speculative_state_draft() override {
211 llama_perf_context_print(ctx_dft);
212
213 llama_free(ctx_dft);
214
215 common_sampler_free(smpl);
216
217 llama_batch_free(batch);
218 }
219
220 void begin(const llama_tokens & prompt) override {
221 GGML_UNUSED(prompt);
222 }
223
224 void draft(
225 const common_params_speculative & params,
226 const llama_tokens & prompt_tgt,
227 llama_token id_last,
228 llama_tokens & result) override {
229 auto * spec = this;
230
231 auto & batch = spec->batch;
232 auto & ctx_tgt = spec->ctx_tgt;
233 auto & ctx_dft = spec->ctx_dft;
234 auto & smpl = spec->smpl;
235 auto & prompt_dft = spec->prompt_dft;
236
237 auto * mem_dft = llama_get_memory(ctx_dft);
238
239 int reuse_i = 0;
240 int reuse_n = 0;
241
242 const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
243
244 llama_tokens prompt_cnv;
245 if (!spec->vocab_cmpt) {
246 std::string text;
247
248 text = common_detokenize(ctx_tgt, prompt_tgt, true);
249 text = replace_to_dft(text);
250
251 LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
252
253 prompt_cnv = common_tokenize(ctx_dft, text, false, true);
254
255 // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
256 const auto * model_tgt = llama_get_model(ctx_tgt);
257 const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
258
259 int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
260 GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
261
262 text.resize(-n_chars);
263 llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
264 text = replace_to_dft(text);
265
266 LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
267 id_last = common_tokenize(ctx_dft, text, false, true)[0];
268 }
269
270 const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
271
272 const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
273
274 // reuse as much as possible from the old draft context
275 // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
276 for (int i = 0; i < (int) prompt_dft.size(); ++i) {
277 int cur = 0;
278 while (i_start + cur < (int) prompt_cur.size() &&
279 i + cur < (int) prompt_dft.size() &&
280 prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
281 cur++;
282 }
283
284 if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
285 reuse_i = i;
286 reuse_n = cur;
287 }
288 }
289
290 LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
291
292 result.clear();
293 result.reserve(params.n_max);
294
295 if (reuse_n == 0) {
296 llama_memory_clear(mem_dft, false);
297 prompt_dft.clear();
298 } else {
299 // this happens when a previous draft has been discarded (for example, due to being too small), but the
300 // target model agreed with it. in this case, we simply pass back the previous results to save compute
301 if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
302 for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
303 result.push_back(prompt_dft[i]);
304
305 if (params.n_max <= (int) result.size()) {
306 break;
307 }
308 }
309
310 return;
311 }
312
313 if (reuse_i > 0) {
314 llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
315 llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
316
317 prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
318 }
319
320 if (reuse_n < (int) prompt_dft.size()) {
321 llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
322 prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
323 }
324 }
325
326 // prepare a batch to evaluate any new tokens in the prompt
327 common_batch_clear(batch);
328
329 for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
330 //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
331 common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
332
333 prompt_dft.push_back(prompt_cur[i]);
334 }
335
336 // we should rarely end-up here during normal decoding
337 if (batch.n_tokens > 0) {
338 //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
339
340 llama_decode(ctx_dft, batch);
341 }
342
343 const llama_pos n_past = prompt_dft.size();
344
345 LOG_DBG("%s: n_past = %d\n", __func__, n_past);
346
347 common_batch_clear(batch);
348 common_batch_add (batch, id_last, n_past, { 0 }, true);
349
350 prompt_dft.push_back(id_last);
351
352 LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
353
354 llama_decode(ctx_dft, batch);
355
356 common_sampler_reset(smpl);
357
358 // sample n_draft tokens from the draft model
359 for (int i = 0; i < params.n_max; ++i) {
360 common_batch_clear(batch);
361
362 common_sampler_sample(smpl, ctx_dft, 0, true);
363
364 const auto * cur_p = common_sampler_get_candidates(smpl, true);
365
366 for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
367 LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
368 k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
369 }
370
371 // add drafted token for each sequence
372 const llama_token id = cur_p->data[0].id;
373
374 common_sampler_accept(smpl, id, true);
375
376 result.push_back(id);
377
378 if (params.n_max <= (int) result.size()) {
379 break;
380 }
381
382 // only collect very high-confidence draft tokens
383 if (cur_p->data[0].p < params.p_min) {
384 break;
385 }
386
387 common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
388
389 // evaluate the drafted tokens on the draft model
390 llama_decode(ctx_dft, batch);
391
392 prompt_dft.push_back(id);
393 }
394
395 if (!spec->vocab_cmpt) {
396 std::string detokenized = common_detokenize(ctx_dft, result, true);
397 detokenized = replace_to_tgt(detokenized);
398 LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
399 result = common_tokenize(ctx_tgt, detokenized, false, true);
400 if (result.size() > (size_t)params.n_max) {
401 result.resize(params.n_max);
402 }
403 }
404 }
405
406 void accept(uint16_t n_accepted) override {
407 // noop
408 GGML_UNUSED(n_accepted);
409 }
410
411 std::string replace_to_dft(const std::string & input) const {
412 std::string result = input;
413
414 for (const auto & pair : this->vocab_map) {
415 size_t pos = result.find(pair.first);
416 while (pos != std::string::npos) {
417 result.replace(pos, pair.first.length(), pair.second);
418 pos = result.find(pair.first, pos + pair.second.length());
419 }
420 }
421
422 return result;
423 }
424
425 std::string replace_to_tgt(const std::string & input) const {
426 std::string result = input;
427
428 for (const auto & pair : this->vocab_map) {
429 size_t pos = result.find(pair.second);
430 while (pos != std::string::npos) {
431 result.replace(pos, pair.second.length(), pair.first);
432 pos = result.find(pair.second, pos + pair.first.length());
433 }
434 }
435
436 return result;
437 }
438};
439
440struct common_speculative_state_eagle3 : public common_speculative_state {
441 common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {}
442
443 void begin(const llama_tokens & prompt) override {
444 GGML_UNUSED(prompt);
445 }
446
447 void draft(
448 const common_params_speculative & params,
449 const llama_tokens & prompt_tgt,
450 llama_token id_last,
451 llama_tokens & draft_tokens) override {
452 // TODO: implement
453 GGML_UNUSED(params);
454 GGML_UNUSED(prompt_tgt);
455 GGML_UNUSED(id_last);
456 GGML_UNUSED(draft_tokens);
457 }
458
459 void accept(uint16_t n_accepted) override {
460 // noop
461 GGML_UNUSED(n_accepted);
462 }
463};
464
465// state of self-speculation (simple implementation, not ngram-map)
466struct common_speculative_state_ngram_simple : public common_speculative_state {
467 common_ngram_simple_config config;
468
469 common_speculative_state_ngram_simple(
470 enum common_speculative_type type,
471 common_ngram_simple_config config)
472 : common_speculative_state(type), config(config) {}
473
474 void begin(const llama_tokens & prompt) override {
475 GGML_UNUSED(prompt);
476 }
477
478 void draft(
479 const common_params_speculative & params,
480 const llama_tokens & prompt_tgt,
481 llama_token id_last,
482 llama_tokens & result) override {
483
484 result = common_ngram_simple_draft(config, prompt_tgt, id_last);
485 GGML_UNUSED(params);
486 }
487
488 void accept(uint16_t n_accepted) override {
489 // noop
490 GGML_UNUSED(n_accepted);
491 }
492};
493
494struct common_speculative_state_ngram_map_k : public common_speculative_state {
495 // draft ngram map for speculative decoding without draft model
496 common_ngram_map map;
497
498 common_speculative_state_ngram_map_k(
499 enum common_speculative_type type,
500 common_ngram_map map)
501 : common_speculative_state(type), map(std::move(map)) {}
502
503 void begin(const llama_tokens & prompt) override {
504 common_ngram_map_begin(map, prompt);
505 }
506
507 void draft(
508 const common_params_speculative & params,
509 const llama_tokens & prompt_tgt,
510 llama_token id_last,
511 llama_tokens & result) override {
512 common_ngram_map_draft(map, prompt_tgt, id_last, result);
513 GGML_UNUSED(params);
514 }
515
516 void accept(uint16_t n_accepted) override {
517 common_ngram_map_accept(map, n_accepted);
518 }
519};
520
521struct common_speculative_state_ngram_mod : public common_speculative_state {
522 common_ngram_mod & mod;
523
524 // the last position in the prompt that was added to the ngram container
525 size_t i_last = 0;
526
527 // length of the last drafted nโgram (number of tokens returned by draft)
528 size_t n_draft_last = 0;
529
530 // consecutive accept rounds with low acceptance fraction (< 0.5)
531 int n_low = 0;
532
533 // enable trace logging if LLAMA_TRACE is set
534 const bool verbose;
535
536 common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod)
537 : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) {
538 static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
539 }
540
541 void begin(const llama_tokens & prompt) override {
542 i_last = 0;
543
544 n_draft_last = 0;
545
546 const size_t n = mod.get_n();
547
548 if (prompt.size() < n) {
549 return;
550 }
551
552 for (size_t i = 0; i < prompt.size() - n; ++i) {
553 mod.add(prompt.data() + i);
554 }
555
556 i_last = prompt.size() - n;
557
558 const double f = (double)mod.get_used() / (double)mod.size();
559 LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
560
561 constexpr double f_thold = 0.25;
562 if (f > f_thold) {
563 LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
564
565 mod.reset();
566 }
567 }
568
569 void draft(
570 const common_params_speculative & params,
571 const llama_tokens & prompt_tgt,
572 llama_token id_last,
573 llama_tokens & result) override {
574 GGML_UNUSED(params);
575
576 n_draft_last = 0;
577
578 const size_t cur_len = prompt_tgt.size();
579 if (cur_len < mod.get_n()) {
580 return;
581 }
582
583 const size_t n = mod.get_n();
584
585 // add new ngrams in chunks
586 if (i_last + 32 < cur_len) {
587 for (size_t i = i_last; i < cur_len - n; ++i) {
588 mod.add(prompt_tgt.data() + i);
589 }
590
591 i_last = cur_len - n;
592 }
593
594 result.resize(n + params.n_max);
595 for (size_t i = 0; i < n - 1; ++i) {
596 result[i] = prompt_tgt[cur_len - n + 1 + i];
597 }
598 result[n - 1] = id_last;
599
600 for (int i = 0; i < params.n_max; ++i) {
601 const llama_token token = mod.get(result.data() + i);
602 if (token == common_ngram_mod::EMPTY) {
603 if (i < params.n_min) {
604 result.clear();
605 return;
606 }
607
608 result.resize(n + i);
609 break;
610 }
611 result[n + i] = token;
612 }
613
614 // only return the m tokens that were drafted
615 for (size_t i = 0; n + i < result.size(); ++i) {
616 result[i] = result[n + i];
617 }
618 result.resize(result.size() - n);
619
620 // store length of drafted nโgram for later acceptance analysis
621 n_draft_last = result.size();
622 }
623
624 void accept(uint16_t n_accepted) override {
625 if (verbose) {
626 LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
627 }
628
629 // compute acceptance fraction if we have a recorded draft length
630 if (n_draft_last > 0) {
631 const double f_acc = (double)n_accepted / (double)n_draft_last;
632 if (f_acc < 0.5) {
633 n_low++;
634 if (n_low >= 3) {
635 LOG_WRN("%s: low acceptance streak (%d) โ resetting ngram_mod\n", __func__, n_low);
636
637 mod.reset();
638 n_low = 0;
639 }
640 } else {
641 n_low = 0;
642 }
643 }
644 }
645};
646
647struct common_speculative_state_ngram_cache : public common_speculative_state {
648 uint16_t n_draft;
649 bool save_dynamic;
650 bool save_static;
651
652 common_ngram_cache ngram_cache_context;
653 common_ngram_cache ngram_cache_dynamic;
654 common_ngram_cache ngram_cache_static;
655
656 size_t cache_size = 0; // number of tokens in n-gram cache
657
658 common_speculative_state_ngram_cache(
659 const enum common_speculative_type type,
660 const std::string & path_static,
661 const std::string & path_dynamic,
662 uint16_t n_draft,
663 bool save_dynamic,
664 bool save_static)
665 : common_speculative_state(type)
666 , n_draft(n_draft)
667 , save_dynamic(save_dynamic)
668 , save_static(save_static)
669 {
670 if (!path_static.empty()) {
671 try {
672 ngram_cache_static = common_ngram_cache_load(path_static);
673 } catch (...) {
674 LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
675 GGML_ABORT("Couldn't read static lookup cache");
676 }
677 }
678
679 if (!path_dynamic.empty()) {
680 try {
681 ngram_cache_dynamic = common_ngram_cache_load(path_dynamic);
682 } catch (...) {
683 LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
684 GGML_ABORT("Couldn't read dynamic lookup cache");
685 }
686 }
687 }
688
689 void begin(const llama_tokens & prompt) override {
690 GGML_UNUSED(prompt);
691 }
692
693 void draft(
694 const common_params_speculative & params,
695 const llama_tokens & prompt_tgt,
696 llama_token id_last,
697 llama_tokens & result) override {
698 GGML_UNUSED(params);
699
700 if (cache_size < prompt_tgt.size() + 1) {
701 llama_tokens tokens_new;
702 tokens_new.reserve(prompt_tgt.size() + 1 - cache_size);
703 for (size_t j = cache_size; j < prompt_tgt.size(); ++j) {
704 tokens_new.push_back(prompt_tgt[j]);
705 }
706 tokens_new.push_back(id_last); // add the last token
707
708 // Update context ngram cache with new prompt_tgt:
709 common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
710 tokens_new, tokens_new.size(), false);
711 cache_size = prompt_tgt.size() + 1;
712 }
713
714 llama_tokens inp;
715 inp.reserve(prompt_tgt.size() + 1);
716 for (size_t j = 0; j < prompt_tgt.size(); ++j) {
717 inp.push_back(prompt_tgt[j]);
718 }
719 inp.push_back(id_last);
720
721 result.push_back(id_last);
722
723 common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
724 ngram_cache_context,
725 ngram_cache_dynamic,
726 ngram_cache_static);
727
728 if (result.size() > 0) {
729 // delete first token in result (which is the id_last token)
730 result.erase(result.begin());
731 }
732 }
733
734 void accept(uint16_t n_accepted) override {
735 // TODO: noop
736 GGML_UNUSED(n_accepted);
737 }
738};
739
740struct common_speculative {
741 std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
742 common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
743};
744
745static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
746 uint16_t size_key = config.params.ngram_size_n;
747 uint16_t size_value = config.params.ngram_size_m;
748 bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
749 uint16_t min_hits = config.params.ngram_min_hits;
750
751 return common_ngram_map(size_key, size_value, key_only, min_hits);
752}
753
754static common_speculative_state_ngram_cache create_state_ngram_cache(
755 const std::string & path_static, const std::string & path_dynamic,
756 const common_speculative_config & config) {
757 uint16_t n_draft = 8; // TODO get from config?
758
759 // TODO bool param in common/common.h to set save_static/save_dynamic?
760 bool save_static = false;
761 bool save_dynamic = false;
762
763 common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
764
765 return state;
766}
767
768std::string common_speculative_type_name_str() {
769 std::string result;
770 for (size_t i = 0; i < common_speculative_types.size(); i++) {
771 if (i > 0) {
772 result += ", ";
773 }
774 result += common_speculative_type_to_str(common_speculative_types[i]);
775 }
776 return result;
777}
778
779std::string common_speculative_type_to_str(enum common_speculative_type type) {
780 switch (type) {
781 case COMMON_SPECULATIVE_TYPE_NONE: return "none";
782 case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
783 case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
784 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
785 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
786 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
787 case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod";
788 case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
789 default: return "unknown";
790 }
791}
792
793enum common_speculative_type common_speculative_type_from_name(const std::string & name) {
794 const auto it = common_speculative_type_from_name_map.find(name);
795 if (it == common_speculative_type_from_name_map.end()) {
796 return COMMON_SPECULATIVE_TYPE_COUNT;
797 }
798 return it->second;
799}
800
801bool common_speculative_is_compat(llama_context * ctx_tgt) {
802 auto * mem = llama_get_memory(ctx_tgt);
803 if (mem == nullptr) {
804 return false;
805 }
806
807 bool res = true;
808
809 llama_memory_clear(mem, true);
810
811 // eval 2 tokens to check if the context is compatible
812 std::vector<llama_token> tmp;
813 tmp.push_back(0);
814 tmp.push_back(0);
815
816 int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
817 if (ret != 0) {
818 LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
819 res = false;
820 goto done;
821 }
822
823 // try to remove the last tokens
824 if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
825 LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
826 res = false;
827 goto done;
828 }
829
830done:
831 llama_memory_clear(mem, true);
832 llama_synchronize(ctx_tgt);
833
834 return res;
835}
836
837// initialization of the speculative decoding system
838//
839common_speculative * common_speculative_init(
840 common_params_speculative & params,
841 llama_context * ctx_tgt) {
842 llama_context * ctx_dft = nullptr;
843 if (params.model_dft) {
844 ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
845 if (ctx_dft == nullptr) {
846 LOG_ERR("%s", "failed to create draft context\n");
847 return nullptr;
848 }
849 }
850
851 // Compute the implementations to use based on the config and their order of preference
852 std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
853 {
854 bool has_draft = !params.mparams_dft.path.empty();
855 bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
856
857 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
858 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
859 bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
860 bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
861 bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
862
863 // In a more complex implementation we could use the same implementation but with different parameters.
864 // This was initially used in PR-18471 but removed to simplify the code.
865 if (has_ngram_simple) {
866 // This implementation can guess a lot of tokens without any draft model.
867 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params));
868 }
869 if (has_ngram_map_k) {
870 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params));
871 }
872 if (has_ngram_map_k4v) {
873 // This implementation can guess tokens with high acceptance rate but is more expensive.
874 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
875 }
876 if (has_ngram_mod) {
877 // shared instance for all speculative decoding contexts
878 if (!params.ngram_mod) {
879 params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
880
881 LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
882 params.ngram_size_n, params.ngram_mod->size(),
883 (float)(params.ngram_mod->size_bytes())/1024/1024);
884
885 if (params.ngram_size_n < 16) {
886 LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
887 }
888 }
889
890 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params));
891 }
892 if (has_ngram_cache) {
893 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
894 }
895 if (has_draft) {
896 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
897 }
898 if (has_draft_eagle3) {
899 configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
900 }
901 }
902
903 std::vector<std::unique_ptr<common_speculative_state>> impls = {};
904
905 for (const common_speculative_config & config : configs) {
906 LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str());
907 switch (config.type) {
908 case COMMON_SPECULATIVE_TYPE_NONE:
909 break;
910 case COMMON_SPECULATIVE_TYPE_DRAFT: {
911 impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
912 /* .ctx_tgt = */ ctx_tgt,
913 /* .ctx_dft = */ ctx_dft,
914 /* .replacements = */ params.replacements
915 ));
916 break;
917 }
918 case COMMON_SPECULATIVE_TYPE_EAGLE3: {
919 impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
920 break;
921 }
922 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
923 common_ngram_map ngram_map = get_common_ngram_map(config);
924
925 uint16_t ngram_size_key = ngram_map.size_key;
926 uint16_t mgram_size_value = ngram_map.size_value;
927
928 auto config_simple = common_ngram_simple_config {
929 /* .size_ngram = */ ngram_size_key,
930 /* .size_mgram = */ mgram_size_value
931 };
932 auto state = std::make_unique<common_speculative_state_ngram_simple>(
933 /* .type = */ config.type,
934 /* .state = */ config_simple
935 );
936 impls.push_back(std::move(state));
937 break;
938 }
939 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
940 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
941 impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
942 (config.type),
943 get_common_ngram_map(config)
944 ));
945 break;
946 }
947 case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
948 GGML_ASSERT(config.params.ngram_mod);
949 impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
950 break;
951 }
952 case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
953 auto state = create_state_ngram_cache(
954 params.lookup_cache_static, params.lookup_cache_dynamic, config);
955 impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
956 break;
957 }
958 default:
959 break;
960 }
961 }
962
963 if (impls.empty()) {
964 LOG_WRN("%s", "no implementations specified for speculative decoding\n");
965 return nullptr;
966 }
967
968 auto * result = new common_speculative {
969 /* .impls = */ std::move(impls)
970 };
971
972 return result;
973}
974
975void common_speculative_free(common_speculative * spec) {
976 if (spec == nullptr) {
977 return;
978 }
979
980 delete spec;
981}
982
983void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) {
984 if (spec == nullptr) {
985 return;
986 }
987
988 for (auto & impl : spec->impls) {
989 common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
990 impl->begin(prompt);
991 impl->n_call_begin++;
992 }
993}
994
995llama_tokens common_speculative_draft(
996 common_speculative * spec,
997 const common_params_speculative & params,
998 const llama_tokens & prompt_tgt, // specified in target model vocab
999 llama_token id_last) {
1000 llama_tokens result;
1001
1002 spec->curr_impl = nullptr; // reset current implementation
1003
1004 for (auto & impl : spec->impls) {
1005 {
1006 common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
1007 impl->draft(params, prompt_tgt, id_last, result);
1008 impl->n_call_draft++;
1009 }
1010
1011 if (!result.empty()) {
1012 LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
1013 common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
1014 impl.get()->n_call_draft, result.size());
1015
1016 spec->curr_impl = impl.get(); // set current implementation for stats
1017 impl->n_gen_drafts++;
1018 impl->n_gen_tokens += result.size();
1019
1020 break; // We have a draft, so break out of the loop and return it.
1021 }
1022 }
1023
1024 return result;
1025}
1026
1027void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
1028 if (n_accepted == 0) {
1029 return;
1030 }
1031
1032 common_speculative_state * impl = spec->curr_impl;
1033
1034 GGML_ASSERT(impl);
1035
1036 {
1037 common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
1038 if (n_accepted > 0) {
1039 impl->n_acc_drafts++;
1040 impl->n_acc_tokens += n_accepted;
1041 }
1042
1043 impl->accept(n_accepted);
1044 impl->n_call_accept++;
1045 }
1046}
1047
1048void common_speculative_print_stats(const common_speculative * spec) {
1049 if (spec == nullptr) {
1050 return;
1051 }
1052
1053 for (const auto & impl : spec->impls) {
1054 std::string str_perf;
1055 if (impl->gen_perf) {
1056 std::ostringstream oss;
1057 oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", ";
1058 oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", ";
1059 oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0;
1060 str_perf = ", dur(b,g,a) = " + oss.str() + " ms";
1061 } else {
1062 str_perf = "";
1063 }
1064
1065 LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
1066 common_speculative_type_to_str(impl->type).c_str(),
1067 impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
1068 impl->n_gen_drafts,
1069 impl->n_acc_drafts,
1070 impl->n_gen_tokens,
1071 impl->n_acc_tokens,
1072 str_perf.c_str());
1073 }
1074}