1#pragma once
  2
  3#include "llama.h"
  4
  5#include "common.h"
  6
#include <memory>
#include <string>
#include <vector>
  9
 10// common_sampler extends llama_sampler with additional functionality:
 11//
 12//  - grammar support
 13//  - custom sampler logic based on the parameters
 14//  - history of the last accepted tokens
 15//  - performance metrics
 16//
// The goal is to have a common implementation of the sampling logic shared across the examples.
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
// complex (top-k, top-p, etc.).
 20//
 21// Another example is related to the grammar. In general, the grammar constraints applied on the full
 22// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
 23// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
 24// grammar constraints are applied to the full vocabulary and the token is resampled.
 25//
 26// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
 27// be moved into the core llama library.
 28//
 29// For convenience, the common_sampler also maintains a container with the current candidate tokens.
 30// This can be used to access the probabilities of the rest of the non-sampled tokens.
 31//
 32// TODO: measure grammar performance
 33//
 34
// Opaque sampler handle. Its definition is not in this header; callers work with it only
// through pointers passed to the functions declared below.
struct common_sampler;
 36
// llama_sampler API overloads

// Create a common_sampler for `model`, configured from `params`.
// note: can mutate params in some cases - which fields change, and when, is decided by the
// implementation, so do not assume `params` is unchanged after this call.
// Release the result with common_sampler_free() (or hold it in a common_sampler_ptr).
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);

// Destroy a sampler obtained from common_sampler_init() / common_sampler_clone().
// NOTE(review): whether nullptr is accepted is not visible from this header - confirm in the implementation.
void common_sampler_free(struct common_sampler * gsmpl);

// Record `token` as accepted by the sampler.
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
// (presumably only by the sampling chain when false - confirm in the implementation).
void                    common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
// Reset the sampler state.
// NOTE(review): exactly which state is reset (chain, grammar, token history) is not shown here.
void                    common_sampler_reset (struct common_sampler * gsmpl);
// Create a copy of `gsmpl`.
// The result is presumably owned by the caller and released with common_sampler_free() - confirm.
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 48
// Print performance metrics for the context and/or the sampler.
// arguments can be nullptr to skip printing (the part associated with that argument)
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);

// get the underlying llama_sampler_chain
// Note: returns a non-const llama_sampler * even though gsmpl is const.
// NOTE(review): the chain presumably remains owned by gsmpl - do not free the returned pointer; confirm.
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 54
// extended sampling implementation:
//
// - set logits
// - apply the configured sampler chain
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
// gsmpl         - the sampler to use
// ctx           - the context whose logits are sampled
// idx           - which logits in ctx to sample from (presumably the output index within the last batch - confirm)
// grammar_first - see above
//
// returns the sampled token. The token is not accepted by this call; use common_sampler_accept()
// (the sample-then-accept pairing is shown in the common_sampler_sample_and_accept_n comment below).
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 66
// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
//      common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
//      llama_token token = common_sampler_sample(gsmpl, ctx, idx);
//      common_sampler_accept(gsmpl, token, true);
//
// i.e. the returned tokens have (presumably all) been accepted, including by the grammar.
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
// so this returns between 1 and draft.size() + 1 tokens
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
 87
// Return the seed used by the sampler.
// NOTE(review): presumably the effective seed chosen at init (e.g. when a random seed was requested) - confirm.
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 89
// helpers

// access the internal list of current candidate tokens
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
// the .sorted flag of the result indicates whether the returned candidates are sorted
// NOTE(review): the array is internal to gsmpl; presumably it is only valid until the next call that
// changes the candidates (e.g. the next sample) - confirm before holding on to it.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);

// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);

// print the sampler chain into a string
std::string common_sampler_print(const struct common_sampler * gsmpl);

// get a string representation of the last accepted tokens
// n - presumably the number of most recent accepted tokens to include; ctx is presumably used to
//     convert them to text - confirm in the implementation.
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);

// Convert a sampler type to its single-character / string name.
char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);

// Parse sampler types from a list of names, or from a string with one character per sampler.
// allow_alt_names - presumably also accept alternative spellings of the names - confirm.
// NOTE(review): how unknown names/characters are handled is not visible from this header.
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
111
// Create a grammar-constraining llama_sampler for `vocab`.
// grammar_kind - the kind/format of grammar_data (accepted values are defined by the implementation)
// grammar_data - the grammar definition itself
// NOTE(review): "llg" presumably refers to the llguidance grammar backend; behavior when that backend
// is unavailable (e.g. a nullptr result) is not visible from this header - confirm.
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                const char * grammar_kind, const char * grammar_data);
114
115struct common_sampler_deleter {
116    void operator()(common_sampler * s) { common_sampler_free(s); }
117};
118
119typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;