#include "sampling.h"

#include "common.h"
#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <unordered_map>

// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // pos points at the next write slot, so the last element is one step behind it
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    void push_back(const T & value) {
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // reset only the buffer state - the underlying storage is kept as is
        sz    = 0;
        first = 0;
        pos   = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz       = 0;
    size_t first    = 0;
    size_t pos      = 0;
    std::vector<T> data;
};
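
// the index arithmetic above is easiest to sanity-check on a tiny example; the
// following usage sketch (comments only, not compiled) traces a capacity-3 buffer
// through one overwrite:
//
//   ring_buffer<llama_token> rb(3);
//   rb.push_back(10); rb.push_back(20); rb.push_back(30);
//   rb.push_back(40);  // full: overwrites the oldest element, front() is now 20
//   rb.back();         // 40 - the most recently pushed element
//   rb.rat(0);         // 40 - rat(i) indexes from the back (most recent first)
//   rb.to_vector();    // { 20, 30, 40 } - oldest to newest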

struct common_sampler {
    common_params_sampling params;

    struct llama_sampler * grmr;
    struct llama_sampler * chain;

    ring_buffer<llama_token> prev;

    std::vector<llama_token_data> cur;

    llama_token_data_array cur_p;

    void reset() {
        prev.clear();

        llama_sampler_reset(chain);
    }

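    // populate cur/cur_p with the token candidates for output position idx,
    // preferring data that the backend has already computed:
    //   1. backend-sampled probs  -> ids + logits + probs
    //   2. backend-sampled logits -> ids + logits (probs left at 0)
    //   3. otherwise              -> the full n_vocab logits array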
    void set_logits(struct llama_context * ctx, int idx) {
        const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
        const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
        const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);

        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);

        const int n_vocab = llama_vocab_n_tokens(vocab);

        if (sampled_probs) {
            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
            cur.resize(sampled_probs_count);
            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
            }
        } else if (sampled_logits) {
            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
            cur.resize(sampled_logits_count);
            for (uint32_t i = 0; i < sampled_logits_count; i++) {
                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
            }
        } else {
            const auto * logits = llama_get_logits_ith(ctx, idx);
            GGML_ASSERT(logits != nullptr);
            cur.resize(n_vocab);
            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
            }
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }

    common_time_meas tm() {
        return common_time_meas(t_total_us, params.no_perf);
    }

    mutable int64_t t_total_us = 0;
};

std::string common_params_sampling::print() const {
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
            mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);

    return std::string(result);
}

struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

    llama_sampler * grmr  = nullptr;
    llama_sampler * chain = llama_sampler_chain_init(lparams);

    std::vector<llama_sampler *> samplers;

    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
    } else {
        std::vector<std::string> trigger_patterns;
        std::vector<llama_token> trigger_tokens;
        for (const auto & trigger : params.grammar_triggers) {
            switch (trigger.type) {
                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                {
                    const auto & word = trigger.value;
                    trigger_patterns.push_back(regex_escape(word));
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                {
                    trigger_patterns.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                {
                    const auto & pattern = trigger.value;
                    std::string anchored = "^$";
                    if (!pattern.empty()) {
                        anchored = (pattern.front() != '^' ? "^" : "")
                                 + pattern
                                 + (pattern.back() != '$' ? "$" : "");
                    }
                    trigger_patterns.push_back(anchored);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
                {
                    const auto token = trigger.token;
                    trigger_tokens.push_back(token);
                    break;
                }
                default:
                    GGML_ASSERT(false && "unknown trigger type");
            }
        }
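
        // for example (illustrative only): a PATTERN_FULL trigger value "<tool_call>.*"
        // is anchored to "^<tool_call>.*$", while a WORD trigger value is regex-escaped
        // first, so any metacharacters in the word are matched literally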

        std::vector<const char *> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto & regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }

        if (!params.grammar.empty()) {
            if (params.grammar_lazy) {
                grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                        trigger_patterns_c.data(), trigger_patterns_c.size(),
                        trigger_tokens.data(), trigger_tokens.size());
            } else {
                grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
            }
        }
    }

    if (params.has_logit_bias()) {
        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
    }

    if (params.mirostat == 0) {

        bool use_adaptive_p = false; // see below

        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
                case COMMON_SAMPLER_TYPE_DRY:
                    {
                        std::vector<const char *> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto & str : params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }
                        samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
                    break;
                case COMMON_SAMPLER_TYPE_TOP_K:
                    samplers.push_back(llama_sampler_init_top_k(params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    samplers.push_back(llama_sampler_init_infill(vocab));
                    break;
                case COMMON_SAMPLER_TYPE_PENALTIES:
                    samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
                case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
                    // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
                    // a single token, so we will add `dist` at the end of the chain by default,
                    // unless the user specifically included `adaptive-p`. we set this flag here
                    // so we know to add the sampler at the very end.
                    use_adaptive_p = true;
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }

        if (use_adaptive_p) {
            // only if the user explicitly included the adaptive-p sampler
            samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
        } else {
            // default: sample from the distribution
            samplers.push_back(llama_sampler_init_dist(params.seed));
        }
    } else if (params.mirostat == 1) {
        samplers.push_back(llama_sampler_init_temp(params.temp));
        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        samplers.push_back(llama_sampler_init_temp(params.temp));
        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
    } else {
        GGML_ASSERT(false && "unknown mirostat version");
    }
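
    // note: the samplers collected above are applied in the order they were added - each
    // one filters/reweights the candidate list before the final token-picking sampler
    // (dist, adaptive-p or mirostat) selects a single token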

    for (auto * smpl : samplers) {
        llama_sampler_chain_add(chain, smpl);
    }

    if (grmr && params.backend_sampling) {
        LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);

        params.backend_sampling = false;
    }

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .grmr   = */ grmr,
        /* .chain  = */ chain,
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    return result;
}
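
// a typical generation loop looks roughly like this (illustrative sketch - batching
// and error handling omitted; see the examples in the repo for real usage):
//
//   common_params_sampling sparams;
//   common_sampler * smpl = common_sampler_init(model, sparams);
//
//   while (...) {
//       // sample from the last set of logits, then accept the token into the history
//       const llama_token id = common_sampler_sample(smpl, ctx, /* idx */ -1, /* grammar_first */ false);
//       common_sampler_accept(smpl, id, /* accept_grammar */ true);
//       ...
//   }
//
//   common_sampler_free(smpl);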

void common_sampler_free(struct common_sampler * gsmpl) {
    if (!gsmpl) {
        return;
    }

    llama_sampler_free(gsmpl->grmr);
    llama_sampler_free(gsmpl->chain);

    delete gsmpl;
}

void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    if (!gsmpl) {
        return;
    }

    const auto tm = gsmpl->tm();

    if (gsmpl->grmr && accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }

    llama_sampler_accept(gsmpl->chain, token);

    gsmpl->prev.push_back(token);
}

void common_sampler_reset(struct common_sampler * gsmpl) {
    if (!gsmpl) {
        return;
    }

    gsmpl->reset();
}

struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr   = */ gsmpl->grmr ? llama_sampler_clone(gsmpl->grmr) : nullptr,
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
}

void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;

    llama_perf_sampler_data data_smpl;
    llama_perf_context_data data_ctx;

    memset(&data_smpl, 0, sizeof(data_smpl));
    memset(&data_ctx,  0, sizeof(data_ctx));

    if (gsmpl) {
        auto & data = data_smpl;

        data = llama_perf_sampler(gsmpl->chain);

        // note: the sampling time includes the samplers time + extra time spent in common/sampling
        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
    }

    if (ctx) {
        auto & data = data_ctx;

        data = llama_perf_context(ctx);

        const double t_end_ms = 1e-3 * ggml_time_us();

        const double t_total_ms = t_end_ms - data.t_start_ms;
        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);

        llama_memory_breakdown_print(ctx);
    }
}

struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    if (!gsmpl) {
        return nullptr;
    }

    return gsmpl->chain;
}

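// sampling proceeds in up to three stages:
//   1. if a backend sampler has already picked a token for this position, return it as is
//   2. otherwise run the CPU sampler chain (after the grammar, when grammar_first is set)
//   3. if the sampled token is rejected by the grammar, re-apply the grammar to the full
//      candidate list first and sample again (grammar-based rejection sampling)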
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    llama_token id = LLAMA_TOKEN_NULL;

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    // check if a backend sampler has already sampled a token, in which case we
    // return that token id directly
    {
        id = llama_get_sampled_token_ith(ctx, idx);

        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");

            // TODO: simplify
            gsmpl->cur.resize(1);
            gsmpl->cur[0] = { id, 0.0f, 1.0f };
            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };

            return id;
        }
    }

    gsmpl->set_logits(ctx, idx);

    if (grammar_first && grmr) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    id = cur_p.data[cur_p.selected].id;

    // when there is no grammar, or it was already applied above, we are done
    if (grammar_first || !grmr) {
        return id;
    }

    // check if the sampled token fits the grammar (grammar-based rejection sampling)
    {
        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    id = cur_p.data[cur_p.selected].id;

    return id;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        if (draft[i] != id) {
            break;
        }
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}
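
// for example (illustrative): with draft = { A, B, C } and idxs = { 0, 1, 2, 3 }, the
// drafted tokens are verified one by one and the loop stops at the first mismatch, so
// the result always ends with exactly one token that diverges from the draft - or,
// when the whole draft is accepted, with one extra token sampled at idxs.back()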

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i;
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers

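// note: sorting invalidates the previous order of the candidates, so the selected
// token is looked up again after the sort to keep cur_p.selected pointing at the
// same token id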
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    const auto tm = gsmpl->tm();

    auto * res = &gsmpl->cur_p;

    if (do_sort && !res->sorted) {
        // remember the selected token before sorting
        const llama_token id = res->data[res->selected].id;

        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
            return a.p > b.p;
        });

        // restore the selected token after sorting
        for (size_t i = 0; i < res->size; ++i) {
            if (res->data[i].id == id) {
                res->selected = i;
                break;
            }
        }

        res->sorted = true;
    }

    return res;
}

llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}

std::string common_sampler_print(const struct common_sampler * gsmpl) {
    std::string result = "logits ";

    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ");
        result += std::string(llama_sampler_name(smpl)) + " ";
    }

    return result;
}

std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    for (int i = n - 1; i >= 0; i--) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += common_token_to_piece(ctx_main, id);
    }

    return result;
}

char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
        case COMMON_SAMPLER_TYPE_ADAPTIVE_P:  return 'a';
        default : return '?';
    }
}

std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
        case COMMON_SAMPLER_TYPE_ADAPTIVE_P:  return "adaptive_p";
        default : return "";
    }
}

std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
        { "adaptive_p",  COMMON_SAMPLER_TYPE_ADAPTIVE_P },
    };

    // sampler names are written in multiple ways, so in addition to the
    // canonical names also accept common alternative spellings
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "adaptive-p",  COMMON_SAMPLER_TYPE_ADAPTIVE_P },
    };

    std::vector<common_sampler_type> samplers;
    samplers.reserve(names.size());

    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
            continue;
        }
        if (allow_alt_names) {
            sampler = sampler_alt_name_map.find(name);
            if (sampler != sampler_alt_name_map.end()) {
                samplers.push_back(sampler->second);
                continue;
            }
        }
        LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
    }

    return samplers;
}
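
// for example (illustrative):
//
//   common_sampler_types_from_names({ "top_k", "min-p", "typical" }, /* allow_alt_names */ true)
//
// yields { COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_TYPICAL_P },
// while names that match neither map are skipped with a warning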

std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P),  COMMON_SAMPLER_TYPE_ADAPTIVE_P },
    };

    std::vector<common_sampler_type> samplers;
    samplers.reserve(chars.size());

    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
        }
    }

    return samplers;
}