1#pragma once
 2
 3#include "llama.h"
 4#include "common.h"
 5
// opaque handle for a speculative decoding session; its definition is not visible in this
// header, so the API below only ever handles it through pointers
// (obtained from common_speculative_init, released with common_speculative_free)
struct common_speculative;
 7
 8// comma separated list of all types
 9std::string common_speculative_type_name_str();
10
11// convert string to type
12enum common_speculative_type common_speculative_type_from_name(const std::string & name);
13
14// convert type to string
15std::string common_speculative_type_to_str(enum common_speculative_type type);
16
17// check if the llama_context is compatible for speculative decoding
18// note: clears the memory of the context
19bool common_speculative_is_compat(llama_context * ctx_tgt);
20
21common_speculative * common_speculative_init(
22        common_params_speculative & params,
23        llama_context             * ctx_tgt);
24
25void common_speculative_free(common_speculative * spec);
26
27// optionally call once at the beginning of a new generation
28void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
29
30// sample up to n_draft tokens and add them to the batch using the draft model
31llama_tokens common_speculative_draft(
32                     common_speculative * spec,
33        const common_params_speculative & params,
34                     const llama_tokens & prompt,
35                            llama_token   id_last);
36
37// informs the speculative decoder that n_accepted tokens were accepted by the target model
38void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
39
40// print statistics about the speculative decoding
41void common_speculative_print_stats(const common_speculative * spec);