# Speculative Decoding

llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by drafting several tokens ahead and letting the main model verify them.

[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing them one at a time (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
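The draft-then-verify loop can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not llama.cpp's implementation. Both "models" are hypothetical greedy next-token functions. A drafted token is kept only if it equals the target model's own greedy choice. Verification is written as a per-position loop, whereas a real implementation evaluates all drafted positions in one batch.

```python
from typing import Callable, List

Token = int
# Hypothetical stand-in for a model: maps a context to its greedy next token.
NextTokenFn = Callable[[List[Token]], Token]


def speculative_step(target: NextTokenFn, draft: NextTokenFn,
                     context: List[Token], n_draft: int) -> List[Token]:
    """Return the tokens produced by one draft/verify round (greedy acceptance)."""
    # 1. Draft: the cheap model proposes n_draft tokens autoregressively.
    drafted: List[Token] = []
    ctx = list(context)
    for _ in range(n_draft):
        tok = draft(ctx)
        drafted.append(tok)
        ctx.append(tok)

    # 2. Verify: the target checks every drafted position. A real implementation
    #    evaluates all positions in a single batch; this loop only models the result.
    out: List[Token] = []
    ctx = list(context)
    for tok in drafted:
        expected = target(ctx)
        if expected != tok:
            out.append(expected)  # first mismatch: keep the target's token, drop the rest
            return out
        out.append(tok)
        ctx.append(tok)

    # All drafted tokens matched: the same batch also yields one more target token.
    out.append(target(ctx))
    return out


def count_up(ctx: List[Token]) -> Token:
    """Toy target model: the next token is the previous one plus 1 (mod 10)."""
    return (ctx[-1] + 1) % 10


def flaky_draft(ctx: List[Token]) -> Token:
    """Toy draft model: agrees with count_up except right after token 3."""
    return 0 if ctx[-1] == 3 else count_up(ctx)


if __name__ == "__main__":
    print(speculative_step(count_up, flaky_draft, [0], 4))  # [1, 2, 3, 4]
```

Each round emits between 1 and `n_draft + 1` tokens. With this greedy acceptance rule, the output is identical to greedy decoding with the target model alone. The speedup depends on how many drafted tokens survive verification.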

## Implementations

The `llama-server` application supports several implementations of speculative decoding. An implementation that uses a draft model can be combined with one that does not.

### Draft Model (`draft`)

A much smaller model (called the _draft model_) generates drafts.
Using a draft model is the most common approach to speculative decoding.

### n-gram Cache (`ngram-cache`)

An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
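Here is a rough Python sketch of the idea. It assumes a single n-gram size and a hypothetical confidence threshold `min_prob` for stopping the draft. The function names are illustrative only. llama.cpp's actual cache is more involved, including how external statistics are merged in.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

Stats = Dict[Tuple[int, ...], Counter]


def build_ngram_stats(tokens: List[int], n: int) -> Stats:
    """Count which token follows each n-gram in `tokens`."""
    stats: Stats = defaultdict(Counter)
    for i in range(len(tokens) - n):
        stats[tuple(tokens[i:i + n])][tokens[i + n]] += 1
    return stats


def draft_from_stats(stats: Stats, context: List[int], n: int,
                     max_draft: int, min_prob: float = 0.5) -> List[int]:
    """Greedily extend the context while the most likely continuation is confident enough."""
    ctx = list(context)
    drafted: List[int] = []
    while len(drafted) < max_draft:
        counts = stats.get(tuple(ctx[-n:]))
        if not counts:
            break  # n-gram never seen: stop drafting
        tok, hits = counts.most_common(1)[0]
        if hits / sum(counts.values()) < min_prob:
            break  # continuation too uncertain
        drafted.append(tok)
        ctx.append(tok)
    return drafted


if __name__ == "__main__":
    stats = build_ngram_stats([1, 2, 3, 1, 2, 3, 1, 2, 4], n=2)
    print(draft_from_stats(stats, [3, 1], n=2, max_draft=4))  # [2, 3, 1, 2]
```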

See:

- #5479, #6828, #6848

### n-gram Map (`ngram-simple`, `ngram-map-*`)

These implementations search the token history for patterns and use matching sequences as draft candidates.
They require no additional model but rely on patterns that have already appeared in the token history.
A typical use case is an LLM rewriting existing source code.

#### n-gram Map (`ngram-simple`)

This implementation looks for the most recent earlier occurrence of the current n-gram in the history and creates a draft from the m tokens that followed it. It is the simplest self-speculative approach and has minimal overhead.

```
llama-server [...] --spec-type ngram-simple --draft-max 64
```
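The lookup described above can be sketched on plain token lists. The function name is hypothetical. The sketch returns at most m tokens, and fewer when the match lies close to the end of the history. It returns an empty draft when the current n-gram has not occurred before. Here n and m play the roles of `--spec-ngram-size-n` and `--spec-ngram-size-m`.

```python
from typing import List


def ngram_simple_draft(history: List[int], n: int, m: int) -> List[int]:
    """Draft the m tokens that followed the most recent earlier match of the last n tokens."""
    if len(history) <= n:
        return []
    key = history[-n:]
    # Scan backwards, skipping the current position (the key itself at the end).
    for start in range(len(history) - n - 1, -1, -1):
        if history[start:start + n] == key:
            return history[start + n:start + n + m]
    return []


if __name__ == "__main__":
    print(ngram_simple_draft([5, 6, 7, 8, 9, 5, 6], n=2, m=3))  # [7, 8, 9]
```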

#### n-gram Map Key (`ngram-map-k`)

This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (option `--spec-ngram-min-hits`, default: 1) before generating drafts.

The number of accepted tokens is stored for each used n-gram.

**Example:**
```
llama-server [...] --spec-type ngram-map-k --draft-max 64
```
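One way to model the key/mgram bookkeeping is a dictionary that maps each key to the mgram that last followed it, together with a hit counter. This sketch rests on several assumptions. The rule that a different mgram replaces the stored one and resets its count is hypothetical. The table is rebuilt from scratch rather than updated incrementally. The stored per-n-gram acceptance counts are omitted.

```python
from typing import Dict, List, Tuple

Key = Tuple[int, ...]


def build_key_map(history: List[int], n: int, m: int) -> Dict[Key, Tuple[Key, int]]:
    """Map each key n-gram to (mgram that followed it, consecutive hit count)."""
    table: Dict[Key, Tuple[Key, int]] = {}
    # Only positions whose full mgram is already part of the history.
    for i in range(len(history) - n - m + 1):
        key = tuple(history[i:i + n])
        mgram = tuple(history[i + n:i + n + m])
        prev = table.get(key)
        if prev is not None and prev[0] == mgram:
            table[key] = (mgram, prev[1] + 1)  # same continuation seen again
        else:
            table[key] = (mgram, 1)            # assumption: new continuation resets hits
    return table


def ngram_map_k_draft(history: List[int], n: int, m: int, min_hits: int = 1) -> List[int]:
    """Draft the stored mgram for the current key if it has enough hits."""
    entry = build_key_map(history, n, m).get(tuple(history[-n:]))
    if entry is None or entry[1] < min_hits:
        return []
    return list(entry[0])


if __name__ == "__main__":
    history = [1, 2, 3, 4, 9, 1, 2, 3, 4, 8, 1, 2]
    print(ngram_map_k_draft(history, n=2, m=2, min_hits=2))  # [3, 4]
```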

#### n-gram Map Key-4-Values (`ngram-map-k4v`)

This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts how often each mgram has followed the key. If one mgram is significantly more frequent than the others, it is used as the draft.

The number of accepted tokens is stored for each used n-gram.

**Example:** Server options for text with many long repetitions.
```
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
```
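A sketch of the four-value bookkeeping follows. The description above only says that one mgram must be "significantly more frequent". Several details here are assumptions made for illustration. The factor-of-two dominance rule (`DOMINANCE`) is one. Combining it with `min_hits` is another. So is evicting the least frequent value when all four slots are taken.

```python
from typing import Dict, List, Tuple

Key = Tuple[int, ...]
MAX_VALUES = 4     # up to four mgram values per key (from the description)
DOMINANCE = 2.0    # hypothetical: best count must be >= 2x the runner-up


def build_k4v_table(history: List[int], n: int, m: int) -> Dict[Key, Dict[Key, int]]:
    """Map each key n-gram to at most MAX_VALUES mgrams with occurrence counts."""
    table: Dict[Key, Dict[Key, int]] = {}
    for i in range(len(history) - n - m + 1):
        key = tuple(history[i:i + n])
        mgram = tuple(history[i + n:i + n + m])
        values = table.setdefault(key, {})
        if mgram in values:
            values[mgram] += 1
        else:
            if len(values) >= MAX_VALUES:
                # Assumption: evict the least frequent value to make room.
                del values[min(values, key=values.get)]
            values[mgram] = 1
    return table


def ngram_map_k4v_draft(history: List[int], n: int, m: int, min_hits: int = 1) -> List[int]:
    """Draft the dominant mgram for the current key, or nothing if no value dominates."""
    values = build_k4v_table(history, n, m).get(tuple(history[-n:]))
    if not values:
        return []
    ranked = sorted(values.items(), key=lambda kv: kv[1], reverse=True)
    best_mgram, best_count = ranked[0]
    runner_up = ranked[1][1] if len(ranked) > 1 else 0
    if best_count >= min_hits and best_count >= DOMINANCE * runner_up:
        return list(best_mgram)
    return []


if __name__ == "__main__":
    print(ngram_map_k4v_draft([1, 5, 1, 5, 1, 5, 1, 6, 1], n=1, m=1))  # [5]
```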

### n-gram Mod (`ngram-mod`)

This implementation is a basic n-gram hasher:

- For each n-gram, compute a hash using an LCG (linear congruential generator)
- For each computed hash, store the next token
- During speculation, iteratively compute the rolling hash of the last n tokens and pick the next token from the storage

Some characteristics:

- Lightweight (~16 MB)
- Constant memory and complexity
- Can generate variable draft lengths (i.e., m is not fixed)

Currently, a single hash pool is shared across all server slots, so different requests can benefit from each other.
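The hashing scheme can be sketched as a fixed-size table indexed by an n-gram hash. This sketch makes several assumptions. The hash folds each token into an LCG step using the 64-bit constants from Knuth's MMIX generator; llama.cpp's actual constants and mixing may differ. The window hash is recomputed for each lookup instead of being rolled incrementally. The table size is arbitrary and does not reproduce the ~16 MB figure. Hash collisions are simply tolerated, since a wrong draft token is rejected during verification anyway.

```python
from typing import List

LCG_MUL = 6364136223846793005   # Knuth's MMIX LCG multiplier (assumed, not llama.cpp's)
LCG_INC = 1442695040888963407   # Knuth's MMIX LCG increment
MASK64 = (1 << 64) - 1
EMPTY = -1


def ngram_hash(tokens: List[int]) -> int:
    """Fold each token into a 64-bit LCG state."""
    h = 0
    for tok in tokens:
        h = (h * LCG_MUL + tok + LCG_INC) & MASK64
    return h


class NgramModPool:
    """Fixed-size table: hash(n-gram) -> next token. Memory never grows."""

    def __init__(self, n: int, size: int = 1 << 20) -> None:
        self.n = n
        self.size = size
        self.next_token = [EMPTY] * size

    def update(self, tokens: List[int]) -> None:
        """Record the token that follows every n-gram in `tokens` (later ones overwrite)."""
        for i in range(len(tokens) - self.n):
            slot = ngram_hash(tokens[i:i + self.n]) % self.size
            self.next_token[slot] = tokens[i + self.n]

    def draft(self, context: List[int], max_draft: int) -> List[int]:
        """Chain lookups; the draft ends at the first empty slot (variable length)."""
        window = list(context[-self.n:])
        drafted: List[int] = []
        while len(window) == self.n and len(drafted) < max_draft:
            tok = self.next_token[ngram_hash(window) % self.size]
            if tok == EMPTY:
                break
            drafted.append(tok)
            window = window[1:] + [tok]  # slide the n-token window
        return drafted


if __name__ == "__main__":
    pool = NgramModPool(n=3)  # a single pool could be shared by all slots
    pool.update([10, 11, 12, 13, 14, 15])
    print(pool.draft([10, 11, 12], max_draft=8))  # [13, 14, 15]
```

Because one `NgramModPool` object can serve every request, continuations learned from one slot are immediately available when drafting for another.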

**Sample usage:**

```
# notes:
# - small values of `n` are not recommended
# - MoEs require long drafts
# - dense models: `--draft-min` and `--draft-max` can be reduced

llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
```

Applications:

- Iterating over a block of text/code (e.g., in llama.vim)
- Reasoning models (when they have to repeat their thinking in the final answer)
- Summarization

Example video:

- See #19164

### Differences between ngram-simple, ngram-map and ngram-mod

- `ngram-simple` looks for a previous matching n-gram and inserts the following m-gram.
- `ngram-map-k` also looks for a previous matching n-gram and inserts the following m-gram, but uses an internal hash map of the n-grams in the current context window.
- `ngram-mod` uses a hash pool that is shared across all server slots. The hash pool maps an n-gram hash to the next token (not to the next m-gram as in `ngram-map`).

## Command-Line Options

If a draft model is combined with a draftless implementation, the draftless implementation takes precedence.

```
--draft, --draft-n, --draft-max N       number of tokens to draft for speculative decoding (default: 16)
                                        (env: LLAMA_ARG_DRAFT_MAX)
--draft-min, --draft-n-min N            minimum number of draft tokens to use for speculative decoding
                                        (default: 0)
                                        (env: LLAMA_ARG_DRAFT_MIN)
[...]
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
                                        type of speculative decoding to use when no draft model is provided
                                        (default: none)
--spec-ngram-size-n N                   ngram size N for ngram-simple/ngram-map speculative decoding, length
                                        of lookup n-gram (default: 12)
--spec-ngram-size-m N                   ngram size M for ngram-simple/ngram-map speculative decoding, length
                                        of draft m-gram (default: 48)
--spec-ngram-min-hits N                 minimum hits for ngram-map speculative decoding (default: 1)
```

### `--spec-type TYPE`

Selects the type of speculative decoding to use without a draft model.

| Type | Description |
|------|-------------|
| `none` | No speculative decoding (default) |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
| `ngram-map-k` | Use n-gram pattern matching with n-gram keys |
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram keys and up to four m-gram values (experimental) |
| `ngram-mod` | Use a basic n-gram hasher with a hash pool shared across slots |

**Example:** Server instance used to refactor source code.
```bash
./llama-server [...] --spec-type ngram-simple
```

### `--spec-ngram-size-n N`

Sets the size N of the lookup n-gram for n-gram-map-based speculative decoding.
The n-gram size N determines how many consecutive tokens are looked back at when searching for matching patterns.

### `--spec-ngram-size-m M`

Sets the size M of the draft m-gram for n-gram-map-based speculative decoding.
The m-gram size determines how many tokens are drafted when a match is found.
Larger values can provide more speedup but may reduce the acceptance rate.

### `--spec-ngram-min-hits H`

This option defines how many times a key must have appeared in the token history before its m-gram is used as a draft (default: 1).

## Statistics

Each speculative decoding implementation prints statistics.

```
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
```

```
draft acceptance rate = 0.70312 ( 90 accepted / 128 generated)
statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
```

```
statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms
```

- `#calls(b,g,a)`: number of calls to the begin (new prompt), generation, and accumulation steps of this implementation
- `#gen drafts`: number of drafts generated by this implementation
- `#acc drafts`: number of drafts accepted (at least partially) by the main model
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model
- `dur(b,g,a)`: durations of begin (new prompt), generation, and accumulation (processing of the acceptance result)
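A per-implementation acceptance rate can be read from these counters as `#acc tokens / #gen tokens`. The small parser below is a sketch that assumes the line format shown above; the regex and function name are hypothetical. In the first example, pooling both implementations reproduces the summary line: (73 + 98) / (187 + 110) = 171 / 297 ≈ 0.57576. The counters do not always line up with the summary line, as the ngram_mod example shows. Treat the pooled value as a reading aid rather than a rule.

```python
import re
from typing import Dict, Iterable, Tuple

STATS_RE = re.compile(
    r"statistics (\w+):.*?#gen tokens = (\d+), #acc tokens = (\d+)")

EXAMPLE_LOG = [
    "statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73",
    "statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98",
]


def acceptance_rates(lines: Iterable[str]) -> Tuple[Dict[str, float], float]:
    """Return per-implementation acceptance rates and the pooled rate."""
    rates: Dict[str, float] = {}
    total_gen = total_acc = 0
    for line in lines:
        match = STATS_RE.search(line)
        if match is None:
            continue  # e.g. the "draft acceptance rate" summary line
        name, gen, acc = match.group(1), int(match.group(2)), int(match.group(3))
        rates[name] = acc / gen if gen else 0.0
        total_gen += gen
        total_acc += acc
    pooled = total_acc / total_gen if total_gen else 0.0
    return rates, pooled


if __name__ == "__main__":
    rates, pooled = acceptance_rates(EXAMPLE_LOG)
    print({name: round(rate, 3) for name, rate in rates.items()}, round(pooled, 5))
```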