#pragma once

#include "llama.h"
#include "common.h"

#include <cstdint>
#include <string>
5
// opaque state for speculative decoding; create with common_speculative_init,
// release with common_speculative_free
struct common_speculative;

// comma separated list of all types
std::string common_speculative_type_name_str();

// convert string to type
// NOTE(review): behavior on an unrecognized name is not visible here — confirm
// whether it returns a default/none value or aborts
enum common_speculative_type common_speculative_type_from_name(const std::string & name);

// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);

// check if the llama_context is compatible for speculative decoding
// note: clears the memory of the context
bool common_speculative_is_compat(llama_context * ctx_tgt);

// create a new speculative-decoding state for the given target context
// the caller owns the returned object and must release it with common_speculative_free
common_speculative * common_speculative_init(
        common_params_speculative & params,
                    llama_context * ctx_tgt);

// destroy a state previously returned by common_speculative_init (safe to pair 1:1 with init)
void common_speculative_free(common_speculative * spec);

// optionally call once at the beginning of a new generation
// resets/seeds the internal state with the initial prompt tokens
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);

// sample up to n_draft tokens and add them to the batch using the draft model
// NOTE(review): the drafted tokens are returned to the caller; "add them to the
// batch" appears to describe the caller's responsibility — no batch is passed here
// id_last: the most recent token produced by the target model (drafting continues after it)
llama_tokens common_speculative_draft(
                     common_speculative * spec,
        const common_params_speculative & params,
                     const llama_tokens & prompt,
                              llama_token id_last);

// informs the speculative decoder that n_accepted tokens were accepted by the target model
// (typically used to update acceptance statistics / adapt drafting)
// NOTE(review): uint16_t caps a single report at 65535 tokens — fine for per-step
// acceptance counts, but confirm callers never pass cumulative totals
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);