#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "speculative.h"
#include "log.h"
#include "llama.h"

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

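// basic demonstration of speculative decoding:
//
//   - a small draft model proposes a short sequence of tokens
//   - the target model evaluates the last accepted token plus the entire draft in a single decode
//   - the target sampler accepts the longest prefix of the draft that it agrees with
//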
int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
        return 1;
    }

    if (params.n_predict < -1) {
        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
        return 1;
    }

    common_init();

    if (params.speculative.mparams_dft.path.empty()) {
        LOG_ERR("%s: --model-draft is required\n", __func__);
        return 1;
    }

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model_tgt = NULL;

    llama_context * ctx_tgt = NULL;

    // load the target model
    auto llama_init_tgt = common_init_from_params(params);

    model_tgt = llama_init_tgt->model();
    ctx_tgt   = llama_init_tgt->context();

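    // the target vocab is used below to detect end-of-generation tokens; the draft model
    // is expected to use a compatible vocabulary so that its token ids match the target's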
    const llama_vocab * vocab = llama_model_get_vocab(model_tgt);

    // load the draft model
    llama_model_ptr model_dft;

    // TODO: simplify this logic
    {
        const auto & params_spec = params.speculative;

        auto params_dft = params;

        params_dft.n_parallel   = 1;
        params_dft.n_ctx        = params_spec.n_ctx;
        params_dft.n_batch      = llama_n_ctx_seq(ctx_tgt);
        params_dft.devices      = params_spec.devices;
        params_dft.model        = params_spec.mparams_dft;
        params_dft.n_gpu_layers = params_spec.n_gpu_layers;

        if (params_spec.cpuparams.n_threads > 0) {
            params_dft.cpuparams.n_threads       = params.speculative.cpuparams.n_threads;
            params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
        }

        params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;

        auto mparams_dft = common_model_params_to_llama(params_dft);

        model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
        if (model_dft == nullptr) {
            LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
            return 1;
        }

        params.speculative.model_dft = model_dft.get();
        params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
    }
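    // note: only the draft model is loaded here - the draft context is presumably created
    // later by the speculator from the cparams_dft stored above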

    // tokenize the prompt
    std::vector<llama_token> inp = common_tokenize(ctx_tgt, params.prompt, true, true);

    if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));

        return 1;
    }

    if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));

        return 1;
    }

    LOG("\n\n");

    for (auto id : inp) {
        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
    }

    int n_predict = 0;
    int n_drafted = 0;
    int n_accept  = 0;

    // used to determine end of generation
    bool has_eos = false;

    // ================================================
    // everything until here is standard initialization
    // the relevant stuff for speculative decoding starts here

    const auto t_enc_start = ggml_time_us();

    // target model sampling context
    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

    // eval the prompt - all tokens except the last one
    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));

    // note: keep the last prompt token separate - it is decoded later, together with the first draft
    llama_token id_last = inp.back();

    // all tokens currently in the target context
    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));

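    // n_past is the number of tokens already decoded into the target KV cache
    // (the full prompt except the last token)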
    int n_past = inp.size() - 1;

    // init the speculator
    const auto & params_spec = params.speculative;

    struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);

    common_speculative_begin(spec, prompt_tgt);

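    // reusable batch for the target model: a single sequence holding id_last plus the
    // current draft on each iteration, with capacity equal to the target batch size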
    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

    const auto t_enc_end = ggml_time_us();

    const auto t_dec_start = ggml_time_us();

    while (true) {
        // optionally, generate draft tokens that can be appended to the target batch
        //
        // this is the most important part of the speculation. the more probable the tokens provided
        // here are, the better the performance will be. in theory, this computation can be performed
        // asynchronously and even offloaded to a remote device. it doesn't even have to be based on
        // an LLM - instead, it can provide tokens from a cache or lookup tables.
        //
        llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);

        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());

        // there is always at least one token to evaluate - the last accepted token, id_last
        common_batch_clear(batch_tgt);
        common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);

        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
        {
            // do not waste time on small drafts
            if (draft.size() < (size_t) params_spec.n_min) {
                draft.clear();
            }

            for (size_t i = 0; i < draft.size(); ++i) {
                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
            }

            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());

            llama_decode(ctx_tgt, batch_tgt);
        }
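        // note: the single llama_decode call above evaluates id_last together with all draft tokens,
        // producing logits for every position - this batched verification is the source of the speedup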

        // sample from the full target batch and return the accepted tokens based on the target sampler
        //
        // for a draft token to be accepted, the target sampler has to sample that exact token.
        // in that case, instead of decoding the sampled token as we normally would, we simply continue
        // with the next set of logits from the batch and sample again, until we run out of logits or
        // the sampler disagrees with the draft
        //
        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);

        //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());

        GGML_ASSERT(ids.size() > 0); // there is always at least one sampled token

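        // ids holds the ids.size() - 1 draft tokens that were accepted, followed by one freshly
        // sampled token (either the first disagreement or the token right after the full draft)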
        n_past    += ids.size() - 1;
        n_drafted += draft.size(); // note: we ignore the discarded small drafts
        n_accept  += ids.size() - 1;
        n_predict += ids.size();

        // process the accepted tokens and update contexts
        //
        // this is the standard token post-processing that we normally do,
        // in this case applied to the whole group of accepted tokens at once
        //
        for (size_t i = 0; i < ids.size(); ++i) {
            prompt_tgt.push_back(id_last);

            id_last = ids[i];

            if (llama_vocab_is_eog(vocab, id_last)) {
                has_eos = true;
                break;
            }

            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);

            if (params.use_color && i + 1 < ids.size()) {
                // highlight tokens accepted from the draft in cyan
                LOG("\u001b[%dm%s\u001b[37m", 36, token_str.c_str());
            } else {
                LOG("%s", token_str.c_str());
            }
        }

        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);

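        // the draft tokens that the target rejected are still in the target KV cache
        // at positions >= n_past - remove them before the next iteration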
        {
            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);

            llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
        }

        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
            break;
        }
    }

    const auto t_dec_end = ggml_time_us();

    const int n_input = inp.size();

    LOG("\n\n");

    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, n_input   / ((t_enc_end - t_enc_start) / 1e6f));
    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));

    LOG_INF("\n");
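    // n_drafted counts the draft tokens that were actually evaluated by the target
    // (drafts shorter than n_min are discarded and not counted), while n_accept
    // counts how many of those the target sampler agreed with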
    LOG_INF("n_draft   = %d\n", params_spec.n_max);
    LOG_INF("n_predict = %d\n", n_predict);
    LOG_INF("n_drafted = %d\n", n_drafted);
    LOG_INF("n_accept  = %d\n", n_accept);
    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);

    LOG_INF("\n");
    LOG_INF("draft:\n\n");

    LOG_INF("\n");
    LOG_INF("target:\n\n");
    common_perf_print(ctx_tgt, smpl);

    llama_batch_free(batch_tgt);

    common_sampler_free(smpl);
    common_speculative_free(spec);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}