1#include "llama-sampler.h"
2
3#include "llama-impl.h"
4#include "llama-vocab.h"
5#include "llama-grammar.h"
6
7#include "ggml-cpp.h"
8
9#include <array>
10#include <algorithm>
11#include <cassert>
12#include <cfloat>
13#include <chrono>
14#include <cmath>
15#include <cstdlib>
16#include <cstring>
17#include <ctime>
18#include <numeric>
19#include <random>
20#include <unordered_map>
21#include <stdexcept>
22
// the ring buffer works similarly to std::deque, but with a fixed capacity:
// once it is full, push_back() overwrites the oldest element
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // most recently pushed element
    // `pos` is the slot the *next* push_back() writes to, so the newest element sits one slot before it
    // (sz > 0 implies capacity > 0, because push_back() refuses a zero capacity)
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // appends value; when the buffer is full, the oldest element is dropped
    void push_back(const T & value) {
        if (capacity == 0) {
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // removes and returns the oldest element
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse access: rat(0) is the newest element, rat(size() - 1) the oldest
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copies the elements in order from oldest to newest
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    // empties the buffer; only the bookkeeping is reset, stored values are left in place
    void clear() {
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;    // number of stored elements
    size_t first = 0; // index of the oldest element
    size_t pos = 0;   // index the next push_back() writes to

    std::vector<T> data;
};
133
134// writes result in res, does not mutate cur
135static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
136 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
137 return a.logit > b.logit;
138 };
139
140 constexpr int nbuckets = 128;
141 constexpr float bucket_low = -10.0f;
142 constexpr float bucket_high = 10.0f;
143 constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
144 constexpr float bucket_inter = -bucket_low * bucket_scale;
145
146 std::vector<int> bucket_idx;
147 std::vector<int> histo(nbuckets, 0);
148
149 std::vector<llama_token_data*> bucket_ptrs;
150
151 bucket_idx.reserve(cur.size);
152
153 for (int i = 0; i < (int)cur.size; ++i) {
154 const float val = cur.data[i].logit;
155 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
156 ib = std::max(0, std::min(nbuckets - 1, ib));
157 bucket_idx.push_back(ib);
158 ++histo[ib];
159 }
160 int nhave = 0;
161 int ib = nbuckets - 1;
162 for ( ; ib >= 0; --ib) {
163 nhave += histo[ib];
164 if (nhave >= npartial) {
165 break;
166 }
167 }
168 res.resize(nhave);
169 auto * ptr = res.data();
170 bucket_ptrs.reserve(nbuckets - ib);
171 for (int j = nbuckets - 1; j >= ib; --j) {
172 bucket_ptrs.push_back(ptr);
173 ptr += histo[j];
174 }
175 for (int i = 0; i < (int)cur.size; ++i) {
176 int j = bucket_idx[i];
177 if (j >= ib) {
178 *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
179 }
180 }
181
182 ptr = res.data();
183 int ndone = 0;
184 for (int j = nbuckets - 1; j > ib; --j) {
185 std::sort(ptr, ptr + histo[j], comp);
186 ptr += histo[j];
187 ndone += histo[j];
188 }
189 std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
190}
191
192// reduces the size of cur_p to npartial, keeping only the top npartial elements
193static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
194 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
195 return a.logit > b.logit;
196 };
197
198 if (npartial <= 128) {
199 std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
200
201 cur_p->size = npartial;
202 cur_p->sorted = true;
203
204 return;
205 }
206
207 std::vector<llama_token_data> tmp;
208
209 llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
210
211 std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
212
213 cur_p->size = npartial;
214 cur_p->sorted = true;
215}
216
217static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
218 // iterator for the probabilities
219#ifdef __GNUC__
220 #pragma GCC diagnostic push
221 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
222#endif
223
224 struct probs_iterator {
225 typedef std::input_iterator_tag iterator_category;
226 typedef float value_type;
227 typedef float * pointer;
228 typedef float & reference;
229 typedef ptrdiff_t difference_type;
230
231 const llama_token_data * data;
232
233 bool operator==(const probs_iterator & other) const { return data == other.data; }
234 bool operator!=(const probs_iterator & other) const { return data != other.data; }
235 const float & operator*() const { return data->p; }
236 probs_iterator & operator++() { ++data; return *this; }
237 probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
238 };
239
240#ifdef __GNUC__
241 #pragma GCC diagnostic pop
242#endif
243
244 std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
245
246 return dist(rng);
247}
248
249/*
250static void llama_log_softmax(float * array, size_t size) {
251 float max_l = *std::max_element(array, array + size);
252 float sum = 0.f;
253 for (size_t i = 0; i < size; ++i) {
254 float p = expf(array[i] - max_l);
255 sum += p;
256 array[i] = p;
257 }
258
259 for (size_t i = 0; i < size; ++i) {
260 array[i] = logf(array[i] / sum);
261 }
262}
263*/
264
265static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
266 if (temp <= 0.0f) {
267 // find the token with the highest logit and set the rest to -inf
268 size_t max_i = 0;
269 float max_l = cur_p->data[0].logit;
270
271 for (size_t i = 1; i < cur_p->size; ++i) {
272 if (cur_p->data[i ].logit > max_l) {
273 cur_p->data[max_i].logit = -INFINITY;
274 max_i = i;
275 max_l = cur_p->data[i].logit;
276 } else {
277 cur_p->data[i].logit = -INFINITY;
278 }
279 }
280
281 return;
282 }
283
284 for (size_t i = 0; i < cur_p->size; ++i) {
285 cur_p->data[i].logit /= temp;
286 }
287}
288
289static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
290 GGML_ASSERT(cur_p->size > 0);
291
292 // Sort the logits in descending order if requested
293 if (do_sort && !cur_p->sorted) {
294 llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
295 }
296
297 float max_l = cur_p->data[0].logit;
298 if (!cur_p->sorted) {
299 for (size_t i = 1; i < cur_p->size; ++i) {
300 max_l = std::max(max_l, cur_p->data[i].logit);
301 }
302 }
303
304 float cum_sum = 0.0f;
305
306 for (size_t i = 0; i < cur_p->size; ++i) {
307 float p = expf(cur_p->data[i].logit - max_l);
308 cur_p->data[i].p = p;
309 cum_sum += p;
310 }
311
312 for (size_t i = 0; i < cur_p->size; ++i) {
313 cur_p->data[i].p /= cum_sum;
314 }
315}
316
317static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
318 // if (k >= (int32_t)cur_p->size) {
319 // return;
320 // }
321
322 if (k <= 0) {
323 return;
324 }
325
326 k = std::min(k, (int) cur_p->size);
327
328 // Sort scores in descending order
329 if (!cur_p->sorted) {
330 llama_token_data_array_partial_sort_inplace(cur_p, k);
331 }
332
333 cur_p->size = k;
334}
335
336static uint32_t get_rng_seed(uint32_t seed) {
337 if (seed == LLAMA_DEFAULT_SEED) {
338 // use system clock if std::random_device is not a true RNG
339 static bool is_rd_prng = std::random_device().entropy() == 0;
340 if (is_rd_prng) {
341 return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
342 }
343 std::random_device rd;
344 return rd();
345 }
346 return seed;
347}
348
349// llama_sampler API
350
351struct llama_sampler * llama_sampler_init(
352 struct llama_sampler_i * iface,
353 llama_sampler_context_t ctx) {
354 return new llama_sampler {
355 /* .iface = */ iface,
356 /* .ctx = */ ctx,
357 };
358}
359
360const char * llama_sampler_name(const struct llama_sampler * smpl) {
361 if (!smpl->iface) {
362 return "(null)";
363 }
364
365 return smpl->iface->name(smpl);
366}
367
368void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
369 if (!smpl) {
370 return;
371 }
372
373 if (smpl->iface->accept) {
374 smpl->iface->accept(smpl, token);
375 }
376}
377
378void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
379 if (!smpl) {
380 return;
381 }
382
383 GGML_ASSERT(smpl->iface->apply);
384 smpl->iface->apply(smpl, cur_p);
385}
386
387void llama_sampler_reset(struct llama_sampler * smpl) {
388 if (!smpl) {
389 return;
390 }
391
392 if (smpl->iface->reset) {
393 smpl->iface->reset(smpl);
394 }
395}
396
397struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
398 if (!smpl) {
399 return nullptr;
400 }
401
402 if (smpl->iface->clone) {
403 return smpl->iface->clone(smpl);
404 }
405
406 if (smpl->ctx == nullptr) {
407 return llama_sampler_init(
408 /* .iface = */ smpl->iface,
409 /* .ctx = */ nullptr
410 );
411 }
412
413 GGML_ABORT("the sampler does not support cloning");
414}
415
416void llama_sampler_free(struct llama_sampler * smpl) {
417 if (smpl == nullptr) {
418 return;
419 }
420
421 if (smpl->iface->free) {
422 smpl->iface->free(smpl);
423 }
424
425 delete smpl;
426}
427
// empty sampler
//
// a named placeholder: its accept/apply/reset and backend callbacks do nothing, so it never touches the candidates

struct llama_sampler_empty {
    // reported by the name() callback; the pointer is stored as-is (neither copied nor freed by the sampler),
    // so the caller must keep the string alive
    const char * name;
};

static struct llama_sampler * llama_sampler_init_empty(const char * name);
435
436static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
437 auto * ctx = (llama_sampler_empty *) smpl->ctx;
438 return ctx->name;
439}
440
441static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
442 GGML_UNUSED(smpl);
443 GGML_UNUSED(token);
444}
445
446static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
447 GGML_UNUSED(smpl);
448 GGML_UNUSED(cur_p);
449}
450
// no-op: there is no state to reset
static void llama_sampler_empty_reset(struct llama_sampler * /*smpl*/) {
}
454
455static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
456 auto * ctx = (llama_sampler_empty *) smpl->ctx;
457 return llama_sampler_init_empty(ctx->name);
458}
459
460static void llama_sampler_empty_free(struct llama_sampler * smpl) {
461 delete (llama_sampler_empty *) smpl->ctx;
462}
463
464static bool llama_sampler_empty_backend_init(
465 struct llama_sampler * smpl,
466 ggml_backend_buffer_type_t buft) {
467 GGML_UNUSED(smpl);
468 GGML_UNUSED(buft);
469
470 return true;
471}
472
473static void llama_sampler_empty_backend_accept(
474 struct llama_sampler * smpl,
475 ggml_context * ctx,
476 ggml_cgraph * gf,
477 struct ggml_tensor * selected_token) {
478 GGML_UNUSED(smpl);
479 GGML_UNUSED(ctx);
480 GGML_UNUSED(gf);
481 GGML_UNUSED(selected_token);
482}
483
// no-op: the sampler data tensors are left untouched
static void llama_sampler_empty_backend_apply(
        struct llama_sampler      * /*smpl*/,
        struct ggml_context       * /*ctx*/,
        struct ggml_cgraph        * /*gf*/,
        struct llama_sampler_data * /*data*/) {
}
494
// no-op: the empty sampler has no graph inputs to fill
static void llama_sampler_empty_backend_set_input(struct llama_sampler * /*smpl*/) {
}
498
// callback table of the empty sampler. every callback is provided (none is nullptr); apart from
// name/clone/free they are no-ops, and backend_init always reports support.
// the initializer is positional, so the order must match the field order of llama_sampler_i
static struct llama_sampler_i llama_sampler_empty_i = {
    /* .name              = */ llama_sampler_empty_name,
    /* .accept            = */ llama_sampler_empty_accept,
    /* .apply             = */ llama_sampler_empty_apply,
    /* .reset             = */ llama_sampler_empty_reset,
    /* .clone             = */ llama_sampler_empty_clone,
    /* .free              = */ llama_sampler_empty_free,
    /* .backend_init      = */ llama_sampler_empty_backend_init,
    /* .backend_accept    = */ llama_sampler_empty_backend_accept,
    /* .backend_apply     = */ llama_sampler_empty_backend_apply,
    /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
};
511
512struct llama_sampler * llama_sampler_init_empty(const char * name) {
513 return llama_sampler_init(
514 /* .iface = */ &llama_sampler_empty_i,
515 /* .ctx = */ new llama_sampler_empty {
516 /* .name = */ name,
517 }
518 );
519}
520
521// common backend sampler functionality
522//
523// +name : means that the sampler is support and will run on the backend
524// -name : means that a ggml operator is not supported by the backend
525//
526struct llama_sampler_backend {
527 llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
528
529 const char * get_name() {
530 if (!is_init) {
531 return name.c_str();
532 }
533
534 if (support) {
535 name_ext = "+" + name;
536 } else {
537 name_ext = "-" + name;
538 }
539
540 return name_ext.c_str();
541 }
542
543 void init(bool support) {
544 GGML_ASSERT(this->is_init == false);
545
546 this->is_init = true;
547 this->support = support;
548 }
549
550private:
551 std::string name;
552 std::string name_ext;
553
554 bool is_init;
555 bool support;
556};
557
558// check if all ggml ops used by the sampler are supported by the backend
559static bool llama_sampler_backend_support(
560 llama_sampler * smpl,
561 ggml_backend_buffer_type_t buft) {
562 auto * device = ggml_backend_buft_get_device(buft);
563 if (!device) {
564 // CPU backend always supported
565 return true;
566 }
567
568 ggml_init_params params = {
569 /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
570 /*.mem_buffer =*/ NULL,
571 /*.no_alloc =*/ true,
572 };
573
574 ggml_context_ptr ctx_ptr { ggml_init(params) };
575 if (!ctx_ptr) {
576 throw std::runtime_error(format("failed to create ggml context"));
577 }
578
579 ggml_context * ctx = ctx_ptr.get();
580
581 const int64_t n = 1024*1024;
582
583 llama_sampler_data data = {
584 /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
585 /*.probs = */ nullptr,
586 /*.sampled = */ nullptr,
587 /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
588 };
589
590 ggml_cgraph * gf = ggml_new_graph(ctx);
591
592 smpl->iface->backend_apply(smpl, ctx, gf, &data);
593
594 if (data.logits) {
595 ggml_build_forward_expand(gf, data.logits);
596 }
597
598 if (data.probs) {
599 ggml_build_forward_expand(gf, data.probs);
600 }
601
602 if (data.sampled) {
603 ggml_build_forward_expand(gf, data.sampled);
604 }
605
606 if (data.candidates) {
607 ggml_build_forward_expand(gf, data.candidates);
608 }
609
610 for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
611 struct ggml_tensor * op = ggml_graph_node(gf, i);
612
613 if (!ggml_backend_dev_supports_op(device, op)) {
614 LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
615 __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
616
617 return false;
618 }
619 }
620
621 return true;
622}
623
624// sampler chain
625
626static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
627 return "chain";
628}
629
630static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
631 auto * chain = (llama_sampler_chain *) smpl->ctx;
632
633 time_meas tm(chain->t_sample_us, chain->params.no_perf);
634
635 for (auto & smpl : chain->samplers) {
636 llama_sampler_accept(smpl.ptr, token);
637 }
638
639 chain->n_sample++;
640}
641
642static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
643 auto * chain = (llama_sampler_chain *) smpl->ctx;
644
645 time_meas tm(chain->t_sample_us, chain->params.no_perf);
646
647 bool is_backend = chain->is_init;
648
649 for (auto & smpl : chain->samplers) {
650 if (is_backend && smpl.is_backend) {
651 continue;
652 }
653
654 is_backend = false;
655
656 if (smpl.ptr->iface->apply == nullptr) {
657 continue;
658 }
659
660 llama_sampler_apply(smpl.ptr, cur_p);
661 }
662}
663
664static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
665 auto * chain = (llama_sampler_chain *) smpl->ctx;
666
667 for (auto & smpl : chain->samplers) {
668 llama_sampler_reset(smpl.ptr);
669 }
670}
671
672static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
673 const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
674
675 auto * result = llama_sampler_chain_init(chain_src->params);
676
677 for (const auto & smpl : chain_src->samplers) {
678 llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
679 }
680
681 return result;
682}
683
684static void llama_sampler_chain_free(struct llama_sampler * smpl) {
685 auto * chain = (llama_sampler_chain *) smpl->ctx;
686
687 for (auto & smpl : chain->samplers) {
688 llama_sampler_free(smpl.ptr);
689 }
690
691 delete chain;
692}
693
694static bool llama_sampler_chain_backend_init(
695 struct llama_sampler * smpl,
696 ggml_backend_buffer_type_t buft) {
697 auto * chain = (llama_sampler_chain *) smpl->ctx;
698
699 GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
700
701 chain->is_init = true;
702
703 bool res = true;
704
705 for (auto & smpl : chain->samplers) {
706 bool res_cur = true;
707
708 // to be able to run a sampler on the backend, it has to:
709 // - have the .backend_init() API implemented
710 // - return true during .backend_init()
711 if (smpl.ptr->iface->backend_init) {
712 if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
713 res_cur = false;
714 }
715 } else {
716 res_cur = false;
717 }
718
719 smpl.is_backend = res_cur;
720
721 res = res && res_cur;
722 }
723
724 return res;
725}
726
727static void llama_sampler_chain_backend_accept(
728 struct llama_sampler * smpl,
729 ggml_context * ctx,
730 ggml_cgraph * gf,
731 struct ggml_tensor * selected_token) {
732 auto * chain = (llama_sampler_chain *) smpl->ctx;
733
734 for (auto & smpl : chain->samplers) {
735 if (!smpl.is_backend) {
736 break;
737 }
738
739 if (smpl.ptr->iface->backend_accept) {
740 smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
741 }
742 }
743}
744
745static void llama_sampler_chain_backend_apply(
746 struct llama_sampler * smpl,
747 struct ggml_context * ctx,
748 struct ggml_cgraph * gf,
749 struct llama_sampler_data * data) {
750 auto * chain = (llama_sampler_chain *) smpl->ctx;
751
752 GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
753
754 for (auto & smpl : chain->samplers) {
755 if (!smpl.is_backend) {
756 break;
757 }
758
759 if (smpl.ptr->iface->backend_apply) {
760 smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
761 }
762 }
763}
764
765static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
766 auto * chain = (llama_sampler_chain *) smpl->ctx;
767
768 for (auto & smpl : chain->samplers) {
769 if (!smpl.is_backend) {
770 break;
771 }
772
773 if (smpl.ptr->iface->backend_set_input) {
774 smpl.ptr->iface->backend_set_input(smpl.ptr);
775 }
776 }
777}
778
// callback table of the sampler chain: every callback fans out to the member samplers.
// its address also identifies a sampler as a chain (compared in llama_sampler_sample / llama_sampler_chain_get).
// the initializer is positional, so the order must match the field order of llama_sampler_i
static struct llama_sampler_i llama_sampler_chain_i = {
    /* .name              = */ llama_sampler_chain_name,
    /* .accept            = */ llama_sampler_chain_accept,
    /* .apply             = */ llama_sampler_chain_apply,
    /* .reset             = */ llama_sampler_chain_reset,
    /* .clone             = */ llama_sampler_chain_clone,
    /* .free              = */ llama_sampler_chain_free,
    /* .backend_init      = */ llama_sampler_chain_backend_init,
    /* .backend_accept    = */ llama_sampler_chain_backend_accept,
    /* .backend_apply     = */ llama_sampler_chain_backend_apply,
    /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
};
791
792struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
793 return llama_sampler_init(
794 /* .iface = */ &llama_sampler_chain_i,
795 /* .ctx = */ new llama_sampler_chain {
796 /* .params = */ params,
797 /* .is_init = */ false,
798 /* .samplers = */ {},
799 /* .cur = */ {},
800 /* .t_sample_us = */ 0,
801 /* .n_sample = */ 0,
802 }
803 );
804}
805
806llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
807 const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
808 const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
809 const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
810 const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
811
812 // If a backend sampler has already sampled a token, return it.
813 if (sampled_token != LLAMA_TOKEN_NULL) {
814 LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
815 return sampled_token;
816 }
817
818 const llama_model * model = llama_get_model(ctx);
819 const llama_vocab * vocab = llama_model_get_vocab(model);
820
821 const int n_vocab = llama_vocab_n_tokens(vocab);
822
823 // use pre-allocated buffer from chain if available, otherwise allocate locally
824 std::vector<llama_token_data> * cur_ptr;
825 std::vector<llama_token_data> cur_local;
826
827 if (smpl->iface == &llama_sampler_chain_i) {
828 auto * chain = (llama_sampler_chain *) smpl->ctx;
829 cur_ptr = &chain->cur;
830 } else {
831 cur_ptr = &cur_local;
832 }
833
834 auto & cur = *cur_ptr;
835
836 if (sampled_probs) {
837 const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
838 cur.resize(sampled_probs_count);
839 for (uint32_t i = 0; i < sampled_probs_count; ++i) {
840 cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
841 }
842 } else if (sampled_logits) {
843 const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
844 cur.resize(sampled_logits_count);
845 for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
846 cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
847 }
848 } else {
849 const auto * logits = llama_get_logits_ith(ctx, idx);
850 GGML_ASSERT(logits != nullptr);
851 cur.resize(n_vocab);
852 for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
853 cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
854 }
855 }
856
857 llama_token_data_array cur_p = {
858 /* .data = */ cur.data(),
859 /* .size = */ cur.size(),
860 /* .selected = */ -1,
861 /* .sorted = */ false,
862 };
863
864 llama_sampler_apply(smpl, &cur_p);
865
866 GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
867
868 auto token = cur_p.data[cur_p.selected].id;
869
870 llama_sampler_accept(smpl, token);
871
872 return token;
873}
874
875
876void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
877 auto * p = (llama_sampler_chain *) chain->ctx;
878 p->samplers.push_back({
879 /* .is_backend = */ false,
880 /* .ptr = */ smpl,
881 });
882}
883
884struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
885 if (chain == nullptr) {
886 return nullptr;
887 }
888
889 if (chain->iface != &llama_sampler_chain_i) {
890 return nullptr;
891 }
892
893 if (i == -1) {
894 return chain;
895 }
896
897 const auto * p = (const llama_sampler_chain *) chain->ctx;
898
899 if (i < 0 || (size_t) i >= p->samplers.size()) {
900 return nullptr;
901 }
902
903 return p->samplers[i].ptr;
904}
905
906struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
907 auto * p = (llama_sampler_chain *) chain->ctx;
908
909 if (i < 0 || (size_t) i >= p->samplers.size()) {
910 return nullptr;
911 }
912
913 auto * result = p->samplers[i].ptr;
914 p->samplers.erase(p->samplers.begin() + i);
915
916 return result;
917}
918
919int llama_sampler_chain_n(const struct llama_sampler * chain) {
920 const auto * p = (const llama_sampler_chain *) chain->ctx;
921
922 return p->samplers.size();
923}
924
925//
926// samplers
927//
928
929// greedy
930
931struct llama_sampler_greedy : public llama_sampler_backend {
932};
933
934static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
935 auto * sctx = (llama_sampler_greedy *) smpl->ctx;
936 return sctx->get_name();
937}
938
939static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
940 auto * ctx = (llama_sampler_greedy *) smpl->ctx;
941 GGML_UNUSED(ctx);
942}
943
944static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
945 const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
946 auto * result = llama_sampler_init_greedy();
947
948 // copy the state
949 {
950 auto * result_ctx = (llama_sampler_greedy *) result->ctx;
951
952 GGML_UNUSED(ctx);
953 GGML_UNUSED(result_ctx);
954 }
955
956 return result;
957}
958
959static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
960 delete (llama_sampler_greedy *) smpl->ctx;
961}
962
963static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
964 cur_p->selected = 0;
965 for (size_t i = 1; i < cur_p->size; ++i) {
966 if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
967 cur_p->selected = i;
968 }
969 }
970}
971
972static bool llama_sampler_greedy_backend_init(
973 struct llama_sampler * smpl,
974 ggml_backend_buffer_type_t buft) {
975 auto * sctx = (llama_sampler_greedy *) smpl->ctx;
976
977 const bool res = llama_sampler_backend_support(smpl, buft);
978
979 sctx->init(res);
980
981 return res;
982}
983
984static void llama_sampler_greedy_backend_apply(
985 struct llama_sampler * smpl,
986 struct ggml_context * ctx,
987 struct ggml_cgraph * gf,
988 struct llama_sampler_data * data) {
989 GGML_UNUSED(gf);
990 GGML_UNUSED(smpl);
991
992 struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
993 ggml_set_name(curl, "greedy_argmax");
994
995 data->sampled = curl;
996}
997
// callback table of the greedy sampler. accept, backend_accept and backend_set_input are nullptr:
// the sampler neither tracks accepted tokens nor has graph inputs to fill.
// the initializer is positional, so the order must match the field order of llama_sampler_i
static struct llama_sampler_i llama_sampler_greedy_i = {
    /* .name              = */ llama_sampler_greedy_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_greedy_apply,
    /* .reset             = */ llama_sampler_greedy_reset,
    /* .clone             = */ llama_sampler_greedy_clone,
    /* .free              = */ llama_sampler_greedy_free,
    /* .backend_init      = */ llama_sampler_greedy_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_greedy_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1010
1011struct llama_sampler * llama_sampler_init_greedy() {
1012 return llama_sampler_init(
1013 /* .iface = */ &llama_sampler_greedy_i,
1014 /* .ctx = */ new llama_sampler_greedy {
1015 ("greedy"),
1016 }
1017 );
1018}
1019
1020// dist
1021
1022struct llama_sampler_dist : public llama_sampler_backend {
1023 const uint32_t seed;
1024 uint32_t seed_cur;
1025
1026 std::mt19937 rng;
1027
1028 ggml_tensor * inp_uniform;
1029};
1030
1031static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
1032 auto * sctx = (llama_sampler_dist *) smpl->ctx;
1033 return sctx->get_name();
1034}
1035
// samples one candidate from softmax(logits) using ctx->rng and stores its index in cur_p->selected.
// side effect: cur_p->data[i].p is overwritten with the normalized probabilities.
// an empty array yields selected = -1; a single candidate is selected with p = 1.
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;

    // edge cases
    if (cur_p->size == 0) {
        cur_p->selected = -1;
        return;
    }

    cur_p->selected = 0;

    if (cur_p->size == 1) {
        cur_p->data[0].p = 1.0f;
        return;
    }

    // max logit for numerical stability
    // (a sorted array is in descending logit order, so its first element already is the max)
    float max_l = cur_p->data[0].logit;
    if (!cur_p->sorted) {
        for (size_t i = 1; i < cur_p->size; ++i) {
            max_l = std::max(max_l, cur_p->data[i].logit);
        }
    }

    // apply softmax to obtain the probabilities
    // (p is left unnormalized here; the total is accumulated in double precision)
    double sum_cum = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float p = expf(cur_p->data[i].logit - max_l);
        cur_p->data[i].p = p;
        sum_cum += p;
    }

#if 1
    // sample from the obtained probabilities and normalize the probs in a single pass
    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
    //
    // inverse-CDF sampling on the unnormalized values: pick the first i whose running sum reaches rnd * total
    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
    const double rnd = dist(ctx->rng);

    double sum_run = 0.0f;
    const double sum_tgt = sum_cum*rnd;

    bool found = false;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (!found) {
            // accumulate probs until we reach the target sum
            sum_run += cur_p->data[i].p;
            if (sum_run >= sum_tgt) {
                cur_p->selected = i;
                found = true;
            }
        }

        // normalize probs
        cur_p->data[i].p /= sum_cum;
    }

    // fallback to the last token (don't think this can happen)
    // NOTE(review): sum_run replays exactly the additions of sum_cum and rnd < 1, so the target is reached for
    // finite values; NaN probabilities (e.g. every logit -INFINITY, giving -inf - -inf) would never reach it
    assert(found);
    if (!found) {
        cur_p->selected = cur_p->size - 1;
    }
#else
    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= sum_cum;
    }

    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
#endif
}
1107
1108static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
1109 auto * ctx = (llama_sampler_dist *) smpl->ctx;
1110 ctx->seed_cur = get_rng_seed(ctx->seed);
1111 ctx->rng.seed(ctx->seed_cur);
1112}
1113
1114static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
1115 const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
1116 auto * result = llama_sampler_init_dist(ctx->seed);
1117
1118 // copy the state
1119 {
1120 auto * result_ctx = (llama_sampler_dist *) result->ctx;
1121
1122 result_ctx->rng = ctx->rng;
1123 }
1124
1125 return result;
1126}
1127
1128static void llama_sampler_dist_free(struct llama_sampler * smpl) {
1129 delete (llama_sampler_dist *) smpl->ctx;
1130}
1131
1132static bool llama_sampler_dist_backend_init(
1133 struct llama_sampler * smpl,
1134 ggml_backend_buffer_type_t buft) {
1135 auto * sctx = (llama_sampler_dist *) smpl->ctx;
1136
1137 const bool res = llama_sampler_backend_support(smpl, buft);
1138
1139 sctx->init(res);
1140
1141 return res;
1142}
1143
1144static void llama_sampler_dist_backend_apply(
1145 struct llama_sampler * smpl,
1146 struct ggml_context * ctx,
1147 struct ggml_cgraph * gf,
1148 struct llama_sampler_data * data) {
1149 GGML_UNUSED(gf);
1150
1151 auto * sctx = (llama_sampler_dist *) smpl->ctx;
1152
1153 sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
1154 ggml_set_name (sctx->inp_uniform, "uniform");
1155 ggml_set_input(sctx->inp_uniform);
1156
1157 struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
1158 ggml_set_name(probs, "dist_probs");
1159
1160 struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
1161 ggml_set_name(cumsum, "dist_cumsum");
1162
1163 // The uniform tensor has a random value and we subtract this tensor with
1164 // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
1165 // Recall that each entry in cumsum is the cumulative probability up to that
1166 // index so values stay negative while the cumulative total is below the
1167 // random value, and become zero/positive once the threshold is crossed.
1168 struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
1169 ggml_set_name(diff, "dist_cumsum");
1170
1171 // The ggml_step function produces a tensor where entries are 1 if the
1172 // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
1173 // the index where the cumulative probability exceeds the random value are 0,
1174 // and all entries after that are 1.
1175 struct ggml_tensor * mask = ggml_step(ctx, diff);
1176 ggml_set_name(mask, "dist_mask");
1177
1178 // Taking the sum of the mask gives us the sum of elements after the threshold
1179 // we are interested in.
1180 struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1181 ggml_set_name(idxf, "dist_index_f32");
1182
1183 // Use ggml_scale_bias to scale the index value by -1 and then add the size
1184 // of the mask to that value so we get the correct index ((-1 * idxf) + n).
1185 struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
1186 ggml_set_name(idx, "dist_index_i32");
1187
1188 // Map back to original vocab ids if a candidates tensor is available.
1189 struct ggml_tensor * sampled_token = idx;
1190 if (data->candidates != nullptr) {
1191 struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
1192
1193 sampled_token = ggml_get_rows(ctx, candidates, idx);
1194 ggml_set_name(sampled_token, "dist_sampled_token");
1195 }
1196
1197 data->sampled = sampled_token;
1198 data->probs = probs;
1199}
1200
1201static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
1202 auto * sctx = (llama_sampler_dist *) smpl->ctx;
1203
1204 GGML_ASSERT(sctx->inp_uniform != nullptr);
1205
1206 // We sample in double precision and cast to float to match rnd numbers of
1207 // llama_dampler_dist which uses double precision (sampling from
1208 // std::uniform_real_distribution<double> and
1209 // std::uniform_real_distribution<float> with same rng will produce
1210 // different sequences).
1211 std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1212 const float rnd = dist(sctx->rng);
1213
1214 ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
1215}
1216
1217static struct llama_sampler_i llama_sampler_dist_i = {
1218 /* .name = */ llama_sampler_dist_name,
1219 /* .accept = */ nullptr,
1220 /* .apply = */ llama_sampler_dist_apply,
1221 /* .reset = */ llama_sampler_dist_reset,
1222 /* .clone = */ llama_sampler_dist_clone,
1223 /* .free = */ llama_sampler_dist_free,
1224 /* .backend_init = */ llama_sampler_dist_backend_init,
1225 /* .backend_accept = */ nullptr,
1226 /* .backend_apply = */ llama_sampler_dist_backend_apply,
1227 /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
1228};
1229
1230struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
1231 auto seed_cur = get_rng_seed(seed);
1232 return llama_sampler_init(
1233 /* .iface = */ &llama_sampler_dist_i,
1234 /* .ctx = */ new llama_sampler_dist {
1235 ("dist"),
1236 /* .seed = */ seed,
1237 /* .seed_cur = */ seed_cur,
1238 /* .rng = */ std::mt19937(seed_cur),
1239 /* .inp_uniform = */ nullptr,
1240 }
1241 );
1242}
1243
1244// top-k
1245
1246struct llama_sampler_top_k : public llama_sampler_backend {
1247 const int32_t k;
1248};
1249
1250static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
1251 auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1252 return sctx->get_name();
1253}
1254
1255static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1256 auto * ctx = (llama_sampler_top_k *) smpl->ctx;
1257 llama_sampler_top_k_impl(cur_p, ctx->k);
1258}
1259
1260static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
1261 const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
1262 return llama_sampler_init_top_k(ctx->k);
1263}
1264
1265static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
1266 delete (llama_sampler_top_k *) smpl->ctx;
1267}
1268
1269static bool llama_sampler_top_k_backend_init(
1270 struct llama_sampler * smpl,
1271 ggml_backend_buffer_type_t buft) {
1272 auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1273
1274 const bool res = llama_sampler_backend_support(smpl, buft);
1275
1276 sctx->init(res);
1277
1278 return res;
1279}
1280
1281static void llama_sampler_top_k_backend_apply(
1282 struct llama_sampler * smpl,
1283 struct ggml_context * ctx,
1284 struct ggml_cgraph * gf,
1285 struct llama_sampler_data * data) {
1286 auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1287
1288 struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
1289 ggml_set_name(top_k, "top_k");
1290
1291 if (data->candidates) {
1292 struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1293 data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
1294 data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
1295 ggml_set_name(data->candidates, "top_k_candidates");
1296 } else {
1297 data->candidates = top_k;
1298 }
1299
1300 struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1301 struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
1302 data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
1303 ggml_set_name(top_k_rows, "top_k_rows");
1304
1305 GGML_UNUSED(gf);
1306}
1307
1308static struct llama_sampler_i llama_sampler_top_k_i = {
1309 /* .name = */ llama_sampler_top_k_name,
1310 /* .accept = */ nullptr,
1311 /* .apply = */ llama_sampler_top_k_apply,
1312 /* .reset = */ nullptr,
1313 /* .clone = */ llama_sampler_top_k_clone,
1314 /* .free = */ llama_sampler_top_k_free,
1315 /* .backend_init = */ llama_sampler_top_k_backend_init,
1316 /* .backend_accept = */ nullptr,
1317 /* .backend_apply = */ llama_sampler_top_k_backend_apply,
1318 /* .backend_set_input = */ nullptr,
1319};
1320
1321struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1322 const bool is_empty = (k <= 0);
1323
1324 if (is_empty) {
1325 return llama_sampler_init_empty("?top-k");
1326 }
1327
1328 return llama_sampler_init(
1329 /* .iface = */ &llama_sampler_top_k_i,
1330 /* .ctx = */ new llama_sampler_top_k {
1331 ("top-k"),
1332 /* .k = */ k,
1333 }
1334 );
1335}
1336
1337// top-p
1338
1339struct llama_sampler_top_p : public llama_sampler_backend {
1340 const float p;
1341 const size_t min_keep;
1342
1343 std::vector<llama_token_data> buf_sort;
1344};
1345
1346static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
1347 auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1348 return sctx->get_name();
1349}
1350
// top-p (nucleus) filtering on the CPU path: keep the shortest prefix of the
// candidates, in descending probability order, whose cumulative probability
// reaches ctx->p, but never fewer than ctx->min_keep candidates. On return the
// kept candidates are sorted and cur_p->size is their count.
//
// To avoid fully sorting large vocabularies, unsorted inputs with more than 1024
// candidates are first partially sorted to their top 256 through ctx->buf_sort.
// The sort is widened to all candidates only if those 256 do not reach p.
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_top_p *) smpl->ctx;

    if (ctx->p >= 1.0f) {
        return;
    }

    // compute the probabilities; NOTE(review): the `false` argument presumably
    // skips sorting, since sorting is handled below — confirm in llama_sampler_softmax_impl
    llama_sampler_softmax_impl(cur_p, false);

    size_t k = cur_p->size;
    auto * pdata = cur_p->data;

    auto & buf_sort = ctx->buf_sort;

    // if not sorted, try adaptive top-k sorting
    // (pdata then points into buf_sort, which presumably holds the top k candidates in sorted order)
    if (!cur_p->sorted && cur_p->size > 1024) {
        k = std::min<size_t>(256, cur_p->size);
        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
        pdata = buf_sort.data();
    } else if (!cur_p->sorted) {
        // small candidates -> sort inplace
        llama_token_data_array_partial_sort_inplace(cur_p, k);
    }

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = cur_p->size;

    for (size_t i = 0; i < cur_p->size; ++i) {
        cum_sum += pdata[i].p;

        // Stop once the running sum is at least p AND at least min_keep tokens have been kept.
        // last_idx is set to i+1 so that the current candidate is included in the set.
        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }

        // we exceeded the current top-k heuristic -> increase k and continue
        // (re-sort all candidates into buf_sort; the first k entries stay the same
        // top-k prefix, so cum_sum remains valid)
        if (!cur_p->sorted && i == k - 1) {
            k = cur_p->size;
            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
            pdata = buf_sort.data();
        }
    }

    // Resize the output vector to keep only the top-p tokens
    // (when the candidates were sorted through buf_sort, copy the kept prefix back into cur_p)
    if (!cur_p->sorted) {
        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
        cur_p->sorted = true;
    }

    cur_p->size = last_idx;
}
1405
1406static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
1407 const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
1408 return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
1409}
1410
1411static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
1412 delete (llama_sampler_top_p *) smpl->ctx;
1413}
1414
1415static bool llama_sampler_top_p_backend_init(
1416 struct llama_sampler * smpl,
1417 ggml_backend_buffer_type_t buft) {
1418 auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1419
1420 const bool res = llama_sampler_backend_support(smpl, buft);
1421
1422 sctx->init(res);
1423
1424 return res;
1425}
1426
// Build the graph for top-p (nucleus) filtering on the backend.
//
// The logits are sorted in descending order, converted to probabilities and
// accumulated into a CDF. Every sorted entry whose CDF is still below p is kept,
// plus the first entry at which the CDF reaches p. All other logits receive a
// -INFINITY bias. On return data->logits and data->candidates are both in the
// sorted (descending) order.
//
// NOTE(review): sctx->min_keep is not referenced on this path, unlike the CPU
// path (llama_sampler_top_p_apply). At least one entry is always kept through
// the inclusive set_rows step below.
static void llama_sampler_top_p_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    auto * sctx = (llama_sampler_top_p *) smpl->ctx;

    // Reorder the single-row tensor `a` by the index tensor `b` (result[i] = a[b[i]]):
    // view `a` as ne[0] rows of one element, gather rows, and flatten back to 1-D.
    auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
        GGML_ASSERT(ggml_nrows(a) == 1);
        struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
        struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
        return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
    };

    // Get the sorted logits in descending order.
    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
    ggml_set_name(sorted_idx, "top_p_sorted_idx");

    // Do the sorting via reshape + get_rows
    struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
    ggml_set_name(sorted_logits, "top_p_sorted_logits");

    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
    ggml_set_name(softmax, "top_p_softmax");

    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
    if (data->candidates) {
        data->candidates = ggml_sort(data->candidates, sorted_idx);
    } else {
        data->candidates = sorted_idx;
    }
    ggml_set_name(data->candidates, "top_p_candidates");

    // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
    ggml_set_name(cdf, "top_p_cdf");

    // Invert CDF and add top-p value: cdf_scaled[i] = p - cdf[i]
    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");

    // ggml_step is 1 for strictly positive inputs, so mask[i] = 1 while cdf[i] < p
    // and 0 from the first entry whose CDF reaches p onwards.
    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
    ggml_set_name(mask, "top_p_mask");

    // The mask is a run of ones followed by zeros, so its sum is the index of the
    // first entry whose CDF reaches p.
    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
    ggml_set_name(idxf, "top_p_index_f32");

    // prevent out-of-bounds access (the sum is ne[0] when no CDF value reaches p, e.g. through rounding)
    idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);

    // construct ones tensor to set the value in the mask (idxf * 0 + 1)
    struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
    ggml_set_name(ones, "top_p_ones");

    // Make top-p inclusive: also keep the entry at which the CDF first reaches p,
    // so the kept probability mass is >= p.
    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);

    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);

    // Apply -INFINITY bias for masked-out tokens
    // log(1) = 0 (keep), log(0) = -INF (discard)
    struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
    ggml_set_name(top_p_bias, "top_p_bias");

    // the output logits are the sorted ones, matching the sorted candidates above
    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
    ggml_set_name(data->logits, "top_p_logits");

    GGML_UNUSED(gf);
}
1499
1500static struct llama_sampler_i llama_sampler_top_p_i = {
1501 /* .name = */ llama_sampler_top_p_name,
1502 /* .accept = */ nullptr,
1503 /* .apply = */ llama_sampler_top_p_apply,
1504 /* .reset = */ nullptr,
1505 /* .clone = */ llama_sampler_top_p_clone,
1506 /* .free = */ llama_sampler_top_p_free,
1507 /* .backend_init = */ llama_sampler_top_p_backend_init,
1508 /* .backend_accept = */ nullptr,
1509 /* .backend_apply = */ llama_sampler_top_p_backend_apply,
1510 /* .backend_set_input = */ nullptr,
1511};
1512
1513struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
1514 const bool is_empty = p >= 1.0f;
1515
1516 if (is_empty) {
1517 return llama_sampler_init_empty("?top-p");
1518 }
1519
1520 return llama_sampler_init(
1521 /* .iface = */ &llama_sampler_top_p_i,
1522 /* .ctx = */ new llama_sampler_top_p {
1523 ("top-p"),
1524 /* .p = */ p,
1525 /* .min_keep = */ min_keep,
1526 /* .buf_sort = */ {},
1527 }
1528 );
1529}
1530
1531// min-p
1532
1533struct llama_sampler_min_p : public llama_sampler_backend {
1534 const float p;
1535 const size_t min_keep;
1536};
1537
1538static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
1539 auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1540 return sctx->get_name();
1541}
1542
1543static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1544 auto * ctx = (llama_sampler_min_p *) smpl->ctx;
1545
1546 if (ctx->p <= 0.0f || !cur_p->size) {
1547 return;
1548 }
1549
1550 bool min_p_applied = false;
1551
1552 // if the cur_p aren't sorted, try the unsorted implementation first
1553 if (!cur_p->sorted) {
1554 std::vector<llama_token_data> filtered_tokens;
1555
1556 float max_logit = -FLT_MAX;
1557 for (size_t i = 0; i < cur_p->size; ++i) {
1558 max_logit = std::max(max_logit, cur_p->data[i].logit);
1559 }
1560 const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
1561
1562 for (size_t i = 0; i < cur_p->size; ++i) {
1563 if (cur_p->data[i].logit >= min_logit) {
1564 filtered_tokens.push_back(cur_p->data[i]);
1565 }
1566 }
1567
1568 // if we have enough values the operation was a success
1569 if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
1570 std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
1571 cur_p->size = filtered_tokens.size();
1572 min_p_applied = true;
1573 }
1574 }
1575
1576 // if the cur_p are sorted or the unsorted implementation failed, use this implementation
1577 if (!min_p_applied) {
1578 // Sort the logits in descending order
1579 if (!cur_p->sorted) {
1580 llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
1581 }
1582
1583 const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
1584 size_t i = 1; // first token always matches
1585
1586 for (; i < cur_p->size; ++i) {
1587 if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
1588 break; // prob too small
1589 }
1590 }
1591
1592 // Resize the output vector to keep only the matching tokens
1593 cur_p->size = i;
1594 }
1595}
1596
1597static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
1598 const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
1599 return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
1600}
1601
1602static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
1603 delete (llama_sampler_min_p *) smpl->ctx;
1604}
1605
1606static bool llama_sampler_min_p_backend_init(
1607 struct llama_sampler * smpl,
1608 ggml_backend_buffer_type_t buft) {
1609 auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1610
1611 const bool res = llama_sampler_backend_support(smpl, buft);
1612
1613 sctx->init(res);
1614
1615 return res;
1616}
1617
1618static void llama_sampler_min_p_backend_apply(
1619 struct llama_sampler * smpl,
1620 struct ggml_context * ctx,
1621 struct ggml_cgraph * gf,
1622 struct llama_sampler_data * data) {
1623 auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1624
1625 struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1626 ggml_set_name(max_idx, "max_idx");
1627
1628 struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1629 ggml_set_name(logits_rows, "logits_rows");
1630
1631 struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
1632 ggml_set_name(max_logit, "max_logit");
1633
1634 // Calculate the threshold value.
1635 struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
1636 ggml_set_name(threshold, "min_p_threshold");
1637
1638 // Subtract the threshold from logits.
1639 struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
1640
1641 // Create a mask where logits below the threshold are 0 (discard),
1642 // and others are 1 (keep).
1643 struct ggml_tensor * mask = ggml_step(ctx, sub);
1644 ggml_set_name(mask, "min_p_mask");
1645
1646 // Apply -INFINITY bias for masked-out tokens
1647 // log(1) = 0 (keep), log(0) = -INF (discard)
1648 struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
1649 ggml_set_name(min_p_bias, "min_p_bias");
1650
1651 data->logits = ggml_add(ctx, data->logits, min_p_bias);
1652 ggml_set_name(data->logits, "min_p_logits");
1653
1654 GGML_UNUSED(gf);
1655}
1656
1657static struct llama_sampler_i llama_sampler_min_p_i = {
1658 /* .name = */ llama_sampler_min_p_name,
1659 /* .accept = */ nullptr,
1660 /* .apply = */ llama_sampler_min_p_apply,
1661 /* .reset = */ nullptr,
1662 /* .clone = */ llama_sampler_min_p_clone,
1663 /* .free = */ llama_sampler_min_p_free,
1664 /* .backend_init = */ llama_sampler_min_p_backend_init,
1665 /* .backend_accept = */ nullptr,
1666 /* .backend_apply = */ llama_sampler_min_p_backend_apply,
1667 /* .backend_set_input = */ nullptr,
1668};
1669
1670struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
1671 const bool is_empty = (p <= 0.0f);
1672
1673 if (is_empty) {
1674 return llama_sampler_init_empty("?min-p");
1675 }
1676
1677 return llama_sampler_init(
1678 /* .iface = */ &llama_sampler_min_p_i,
1679 /* .ctx = */ new llama_sampler_min_p {
1680 ("min-p"),
1681 /* .p = */ p,
1682 /* .min_keep = */ min_keep,
1683 }
1684 );
1685}
1686
1687// typical
1688
// Context of the locally typical sampler (CPU only; it has no backend hooks).
struct llama_sampler_typical {
    const float  p;        // probability mass to keep (< 1 when constructed by llama_sampler_init_typical)
    const size_t min_keep; // lower bound on the number of kept candidates
};

// The typical sampler reports a fixed name.
static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    static const char * const name = "typical";
    return name;
}
1697
1698static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1699 auto * ctx = (llama_sampler_typical *) smpl->ctx;
1700
1701 // Reference implementation:
1702 // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
1703 if (ctx->p >= 1.0f) {
1704 return;
1705 }
1706
1707 // Compute the softmax of logits and calculate entropy
1708 llama_sampler_softmax_impl(cur_p, true);
1709
1710 float entropy = 0.0f;
1711 for (size_t i = 0; i < cur_p->size; ++i) {
1712 entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
1713 }
1714
1715 // Compute the absolute difference between negative log probability and entropy for each candidate
1716 std::vector<float> shifted_scores;
1717 for (size_t i = 0; i < cur_p->size; ++i) {
1718 float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
1719 shifted_scores.push_back(shifted_score);
1720 }
1721
1722 // Sort tokens based on the shifted_scores and their corresponding indices
1723 std::vector<size_t> indices(cur_p->size);
1724 std::iota(indices.begin(), indices.end(), 0);
1725
1726 std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
1727 return shifted_scores[a] < shifted_scores[b];
1728 });
1729
1730 // Compute the cumulative probabilities
1731 float cum_sum = 0.0f;
1732 size_t last_idx = indices.size();
1733
1734 for (size_t i = 0; i < indices.size(); ++i) {
1735 size_t idx = indices[i];
1736 cum_sum += cur_p->data[idx].p;
1737
1738 // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
1739 if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
1740 last_idx = i + 1;
1741 break;
1742 }
1743 }
1744
1745 // Resize the output vector to keep only the locally typical tokens
1746 std::vector<llama_token_data> cur_p_new;
1747 for (size_t i = 0; i < last_idx; ++i) {
1748 size_t idx = indices[i];
1749 cur_p_new.push_back(cur_p->data[idx]);
1750 }
1751
1752 // Replace the data in cur_p with the cur_p_new data
1753 std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
1754 cur_p->size = cur_p_new.size();
1755 cur_p->sorted = false;
1756}
1757
1758static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
1759 const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
1760 return llama_sampler_init_typical(ctx->p, ctx->min_keep);
1761}
1762
1763static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1764 delete (llama_sampler_typical *) smpl->ctx;
1765}
1766
1767static struct llama_sampler_i llama_sampler_typical_i = {
1768 /* .name = */ llama_sampler_typical_name,
1769 /* .accept = */ nullptr,
1770 /* .apply = */ llama_sampler_typical_apply,
1771 /* .reset = */ nullptr,
1772 /* .clone = */ llama_sampler_typical_clone,
1773 /* .free = */ llama_sampler_typical_free,
1774 /* .backend_init = */ nullptr,
1775 /* .backend_accept = */ nullptr,
1776 /* .backend_apply = */ nullptr,
1777 /* .backend_set_input = */ nullptr,
1778};
1779
1780struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1781 const bool is_empty = (p >= 1.0f);
1782
1783 if (is_empty) {
1784 return llama_sampler_init_empty("?typical");
1785 }
1786
1787 return llama_sampler_init(
1788 /* .iface = */ &llama_sampler_typical_i,
1789 /* .ctx = */ new llama_sampler_typical {
1790 /* .p = */ p,
1791 /* .min_keep = */ min_keep,
1792 }
1793 );
1794}
1795
1796// temp
1797
1798struct llama_sampler_temp : public llama_sampler_backend {
1799 const float temp;
1800};
1801
1802static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
1803 auto * sctx = (llama_sampler_temp *) smpl->ctx;
1804 return sctx->get_name();
1805}
1806
1807static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1808 const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1809
1810 llama_sampler_temp_impl(cur_p, ctx->temp);
1811}
1812
1813static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
1814 const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
1815 return llama_sampler_init_temp(ctx->temp);
1816}
1817
1818static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1819 delete (llama_sampler_temp *) smpl->ctx;
1820}
1821
1822static void llama_sampler_backend_temp_sampling(
1823 struct ggml_context * ctx,
1824 struct ggml_cgraph * gf,
1825 struct llama_sampler_data * data,
1826 float temp) {
1827 if (temp <= 0.0f) {
1828 // Find the most probable token index.
1829 struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1830 ggml_set_name(max_idx, "temp_max_idx");
1831
1832 if (data->candidates) {
1833 struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1834 data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
1835 } else {
1836 data->candidates = max_idx;
1837 }
1838
1839 struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1840 data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
1841
1842 return;
1843 }
1844
1845 data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
1846
1847 GGML_UNUSED(gf);
1848}
1849
1850static bool llama_sampler_temp_backend_init(
1851 struct llama_sampler * smpl,
1852 ggml_backend_buffer_type_t buft) {
1853 auto * sctx = (llama_sampler_temp *) smpl->ctx;
1854
1855 const bool res = llama_sampler_backend_support(smpl, buft);
1856
1857 sctx->init(res);
1858
1859 return res;
1860}
1861
1862static void llama_sampler_temp_backend_apply(
1863 struct llama_sampler * smpl,
1864 struct ggml_context * ctx,
1865 struct ggml_cgraph * gf,
1866 struct llama_sampler_data * data) {
1867 auto * sctx = (llama_sampler_temp *) smpl->ctx;
1868 llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
1869}
1870
1871static struct llama_sampler_i llama_sampler_temp_i = {
1872 /* .name = */ llama_sampler_temp_name,
1873 /* .accept = */ nullptr,
1874 /* .apply = */ llama_sampler_temp_apply,
1875 /* .reset = */ nullptr,
1876 /* .clone = */ llama_sampler_temp_clone,
1877 /* .free = */ llama_sampler_temp_free,
1878 /* .backend_init = */ llama_sampler_temp_backend_init,
1879 /* .backend_accept = */ nullptr,
1880 /* .backend_apply = */ llama_sampler_temp_backend_apply,
1881 /* .backend_set_input = */ nullptr,
1882};
1883
1884struct llama_sampler * llama_sampler_init_temp(float temp) {
1885 const bool is_empty = temp == 1.0f;
1886
1887 if (is_empty) {
1888 return llama_sampler_init_empty("?temp");
1889 }
1890
1891 return llama_sampler_init(
1892 /* .iface = */ &llama_sampler_temp_i,
1893 /* .ctx = */ new llama_sampler_temp {
1894 ("temp"),
1895 /*.temp = */ temp,
1896 }
1897 );
1898}
1899
1900// temp-ext
1901
1902struct llama_sampler_temp_ext : public llama_sampler_backend {
1903 const float temp;
1904 const float delta;
1905 const float exponent;
1906};
1907
1908static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
1909 auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1910 return sctx->get_name();
1911}
1912
// Dynamic temperature on the CPU path. When delta > 0, the temperature is chosen
// from the normalized entropy of the current distribution:
//   dyn_temp = min_temp + (max_temp - min_temp) * (H / log(n))^exponent
// with min_temp = max(0, temp - delta), max_temp = temp + delta and n the number
// of candidates. The logits are scaled by dyn_temp and the probabilities
// (cur_p->data[i].p) are recomputed. With delta <= 0, plain temperature scaling
// with ctx->temp is applied.
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
    if (ctx->delta > 0) {
        const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;

        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates
        if (cur_p->size <= 1) {
            return;
        }

        // Calculate maximum possible entropy (that of a uniform distribution over the candidates)
        float max_entropy = -logf(1.0f / cur_p->size);

        llama_sampler_softmax_impl(cur_p, true);

        // Calculate entropy of the softmax probabilities
        float entropy = 0.0f;
        for (size_t i = 0; i < cur_p->size; ++i) {
            float prob = cur_p->data[i].p;
            if (prob > 0.0f) { // Ensure no log(0)
                entropy -= prob * logf(prob);
            }
        }

        // Normalize the entropy (max_entropy cannot be 0 here because the early return above guarantees cur_p->size > 1)
        float normalized_entropy = entropy / max_entropy;

        // Map the normalized entropy to the desired temperature range using the power function
        float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);

    #ifdef DEBUG
        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
    #endif

        // Apply the dynamically calculated temperature scaling
        llama_sampler_temp_impl(cur_p, dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        // NOTE(review): data[0] is used as the maximum logit, which assumes
        // llama_sampler_softmax_impl(cur_p, true) left the candidates sorted by
        // descending logit and that temperature scaling preserved that order — confirm.
        const double max_l_double = cur_p->data[0].logit;

        double cum_sum_double = 0.0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            double p = exp(cur_p->data[i].logit - max_l_double);
            cur_p->data[i].p = p; // Store the scaled probability
            cum_sum_double += p;
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }

    #ifdef DEBUG
        // Print the updated top 25 probabilities after temperature scaling
        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
        }
    #endif
    } else {
        // no dynamic range: plain temperature scaling
        llama_sampler_temp_impl(cur_p, ctx->temp);
    }
}
1983
1984static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1985 const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1986 return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
1987}
1988
1989static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1990 delete (llama_sampler_temp_ext *) smpl->ctx;
1991}
1992
1993static bool llama_sampler_temp_ext_backend_init(
1994 struct llama_sampler * smpl,
1995 ggml_backend_buffer_type_t buft) {
1996 auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1997
1998 const bool res = llama_sampler_backend_support(smpl, buft);
1999
2000 sctx->init(res);
2001
2002 return res;
2003}
2004
2005static void llama_sampler_temp_ext_backend_apply(
2006 struct llama_sampler * smpl,
2007 struct ggml_context * ctx,
2008 struct ggml_cgraph * gf,
2009 struct llama_sampler_data * data) {
2010 auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2011
2012 // Revert to standard temperature scaling if delta or temp are non-positive.
2013 if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
2014 llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
2015 return;
2016 }
2017
2018 // Calculate min_temp, max_temp, and max_entropy.
2019 const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
2020 const float max_temp = sctx->temp + sctx->delta;
2021 const float max_entropy = logf(data->logits->ne[0]);
2022
2023 // Calculate the probabilities.
2024 struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
2025 ggml_set_name(probs, "temp_ext_softmax_probs");
2026
2027 // Clamp probabilities to avoid log(0) which would give -inf
2028 struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
2029 ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
2030
2031 // Calculate the entropy, entropy = -ฮฃ(p * log(p)).
2032 struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
2033 struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
2034 struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
2035 struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
2036 ggml_set_name(log_probs, "temp_ext_log_probs");
2037 ggml_set_name(p_log_p, "temp_ext_p_log_p");
2038 ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
2039 ggml_set_name(entropy, "temp_ext_entropy");
2040
2041 // Normalize the entropy, norm_entropy = entropy / max_entropy
2042 struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
2043 ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
2044
2045 // Calculate the dynamic temperature:
2046 // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
2047 //
2048 // Calculate powf(normalized_entropy, exponent) as
2049 // norm_entropy^exponent = exp(exponent * log(norm_entropy))
2050 struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
2051 struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
2052 struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
2053 // With pow_entropy computed we can now compute dyn_temp, scaling by
2054 // (max_temp - min_temp) and then adding min_temp.
2055 struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
2056 ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
2057 ggml_set_name(scaled_log, "temp_ext_scaled_log");
2058 ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
2059 ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
2060
2061 // Scale the logits by the dynamic temperature
2062 struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
2063 ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
2064
2065 data->logits = scaled_logits;
2066}
2067
// Interface table for the dynamic-temperature ("temp-ext") sampler.
// Entries are positional and must follow the field order of llama_sampler_i.
// No accept/reset hooks: the sampler keeps no per-token history. It does provide
// backend_init/backend_apply so it can also run as part of a ggml graph.
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_temp_ext_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_temp_ext_clone,
    /* .free = */ llama_sampler_temp_ext_free,
    /* .backend_init = */ llama_sampler_temp_ext_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
    /* .backend_set_input = */ nullptr,
};
2080
2081struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
2082 const bool is_empty = temp == 1.0f && delta <= 0.0f;
2083
2084 if (is_empty) {
2085 return llama_sampler_init_empty("?temp-ext");
2086 }
2087
2088 auto * res = llama_sampler_init(
2089 /* .iface = */ &llama_sampler_temp_ext_i,
2090 /* .ctx = */ new llama_sampler_temp_ext {
2091 ("temp-ext"),
2092 /* .temp = */ temp,
2093 /* .delta = */ delta,
2094 /* .exponent = */ exponent,
2095 }
2096 );
2097
2098 return res;
2099}
2100
2101// xtc
2102
// State for the XTC ("exclude top choices") sampler.
// NOTE: fields are brace-initialized positionally in llama_sampler_init_xtc, so the
// declaration order must not change.
struct llama_sampler_xtc {
    const float probability; // apply() only acts when a uniform [0,1) draw is <= this
    const float threshold;   // tokens with p >= threshold are candidates for removal
    const size_t min_keep;   // removal is skipped if fewer than this many tokens would remain

    const uint32_t seed;     // seed as passed by the caller
    uint32_t seed_cur;       // value returned by get_rng_seed(seed); used to (re)seed rng

    std::mt19937 rng;        // engine for the per-call "apply or not" draw
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    return "xtc";
}
2117
2118static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2119 auto * ctx = (llama_sampler_xtc *) smpl->ctx;
2120
2121 if (ctx->probability <= 0.0f
2122 || ctx->threshold > 0.5f
2123 || cur_p->size < 2) {
2124 return;
2125 }
2126
2127 std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
2128 float chance = distribution(ctx->rng);
2129 if (chance > ctx->probability) {
2130 return;
2131 }
2132
2133 llama_sampler_softmax_impl(cur_p, true);
2134
2135 int pos_last = 0;
2136
2137 for (size_t i = 0; i < cur_p->size; ++i) {
2138 if (cur_p->data[i].p >= ctx->threshold) {
2139 pos_last = i;
2140 } else {
2141 break;
2142 }
2143 }
2144
2145 if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
2146 cur_p->data += pos_last;
2147 cur_p->size -= pos_last;
2148 }
2149}
2150
2151static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
2152 const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
2153 auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
2154
2155 // copy the state
2156 {
2157 auto * result_ctx = (llama_sampler_xtc *) result->ctx;
2158
2159 result_ctx->rng = ctx->rng;
2160 }
2161
2162 return result;
2163}
2164
2165static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
2166 delete (llama_sampler_xtc *) smpl->ctx;
2167}
2168
2169static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
2170 auto * ctx = (llama_sampler_xtc *) smpl->ctx;
2171 ctx->seed_cur = get_rng_seed(ctx->seed);
2172 ctx->rng.seed(ctx->seed_cur);
2173}
2174
// Interface table for the XTC sampler (positional; must follow llama_sampler_i's field order).
// The sampler does not observe accepted tokens (no accept hook) and provides no backend_* hooks.
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name = */ llama_sampler_xtc_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sample_xtc_apply,
    /* .reset = */ llama_sampler_xtc_reset,
    /* .clone = */ llama_sampler_xtc_clone,
    /* .free = */ llama_sampler_xtc_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2187
2188struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
2189 const bool is_empty = (p <= 0.0f || t > 0.5f);
2190
2191 if (is_empty) {
2192 return llama_sampler_init_empty("?xtc");
2193 }
2194
2195 const auto seed_cur = get_rng_seed(seed);
2196
2197 return llama_sampler_init(
2198 /* .iface = */ &llama_sampler_xtc_i,
2199 /* .ctx = */ new llama_sampler_xtc {
2200 /* .probability = */ p,
2201 /* .threshold = */ t,
2202 /* .min_keep = */ min_keep,
2203 /* .seed = */ seed,
2204 /* .seed_cur = */ seed_cur,
2205 /* .rng = */ std::mt19937(seed_cur),
2206 }
2207 );
2208}
2209
2210// mirostat
2211
// State for the Mirostat (v1) sampler.
// NOTE: brace-initialized positionally in llama_sampler_init_mirostat; keep field order.
struct llama_sampler_mirostat {
    const int32_t n_vocab;  // vocabulary size, used when deriving the top-k cutoff

    const uint32_t seed;    // seed as passed by the caller
    uint32_t seed_cur;      // value returned by get_rng_seed(seed); used to (re)seed rng

    const float tau;        // target surprise (bits, since surprise is computed with -log2f)
    const float eta;        // learning rate for the mu update

    const int32_t m;        // number of most-probable tokens used to estimate s_hat

    float mu;               // adaptive state; initialized and reset to 2*tau

    std::mt19937 rng;
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat";
}
2231
// Mirostat (v1) step:
//   1. estimate s_hat from the top-m candidates,
//   2. derive a top-k cutoff from s_hat, n_vocab and the current mu,
//   3. sample from the truncated, renormalized distribution,
//   4. nudge mu by eta times the error between the observed surprise and tau.
// NOTE(review): with cur_p->size == 1 the estimation loop does not execute, sum_ti_sq stays 0
// and s_hat becomes 0/0 (NaN). Presumably callers never pass a single candidate — confirm.
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;

    // normalize probabilities; the estimate below assumes data[i].p is non-increasing in i
    // (presumably guaranteed by the `true` argument — confirm)
    llama_sampler_softmax_impl(cur_p, true);

    // Estimate s_hat using the most probable m tokens
    // s_hat is the least-squares slope (through the origin) of b_i = log(p_i / p_{i+1})
    // against t_i = log((i+2)/(i+1)).
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
        float t_i = logf(float(i + 2) / float(i + 1));
        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);

    // truncate to the top k candidates, but always keep at least one
    llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));

    // renormalize over the surviving candidates
    llama_sampler_softmax_impl(cur_p, true);

    const int idx = llama_sample_dist(cur_p, ctx->rng);

    cur_p->selected = idx;

    // surprise of the sampled token (in bits) and its deviation from the target
    float observed_surprise = -log2f(cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
2267
2268static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
2269 const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
2270 auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
2271
2272 // copy the state
2273 {
2274 auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
2275
2276 result_ctx->mu = ctx->mu;
2277 result_ctx->rng = ctx->rng;
2278 }
2279
2280 return result;
2281}
2282
2283static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
2284 auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
2285 ctx->mu = 2.0f*ctx->tau;
2286 ctx->seed_cur = get_rng_seed(ctx->seed);
2287 ctx->rng.seed(ctx->seed_cur);
2288}
2289
2290static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
2291 delete (llama_sampler_mirostat *) smpl->ctx;
2292}
2293
// Interface table for the Mirostat (v1) sampler (positional; must follow llama_sampler_i's field order).
// apply() both filters and selects a token; the sampler provides no accept or backend_* hooks.
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name = */ llama_sampler_mirostat_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_apply,
    /* .reset = */ llama_sampler_mirostat_reset,
    /* .clone = */ llama_sampler_mirostat_clone,
    /* .free = */ llama_sampler_mirostat_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2306
2307struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
2308 const auto seed_cur = get_rng_seed(seed);
2309
2310 return llama_sampler_init(
2311 /* .iface = */ &llama_sampler_mirostat_i,
2312 /* .ctx = */ new llama_sampler_mirostat {
2313 /* .n_vocab = */ n_vocab,
2314 /* .seed = */ seed,
2315 /* .seed_cur = */ seed_cur,
2316 /* .tau = */ tau,
2317 /* .eta = */ eta,
2318 /* .m = */ m,
2319 /* .mu = */ 2.0f*tau,
2320 /* .rng = */ std::mt19937(seed_cur),
2321 }
2322 );
2323}
2324
2325// mirostat v2
2326
// State for the Mirostat 2.0 sampler.
// NOTE: brace-initialized positionally in llama_sampler_init_mirostat_v2; keep field order.
struct llama_sampler_mirostat_v2 {
    const uint32_t seed;    // seed as passed by the caller
    uint32_t seed_cur;      // value returned by get_rng_seed(seed); used to (re)seed rng

    const float tau;        // target surprise (bits, since surprise is computed with -log2f)
    const float eta;        // learning rate for the mu update

    float mu;               // current surprise cutoff; initialized and reset to 2*tau

    std::mt19937 rng;
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat-v2";
}
2342
2343static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2344 auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
2345
2346 llama_sampler_softmax_impl(cur_p, true);
2347
2348 // Truncate the words with surprise values greater than mu
2349 cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
2350 return -log2f(candidate.p) > ctx->mu;
2351 }));
2352
2353 if (cur_p->size == 0) {
2354 cur_p->size = 1;
2355 }
2356
2357 // Normalize the probabilities of the remaining words
2358 llama_sampler_softmax_impl(cur_p, true);
2359
2360 const int idx = llama_sample_dist(cur_p, ctx->rng);
2361
2362 cur_p->selected = idx;
2363
2364 float observed_surprise = -log2f(cur_p->data[idx].p);
2365 float e = observed_surprise - ctx->tau;
2366
2367 // Update mu using the learning rate and error
2368 ctx->mu = ctx->mu - ctx->eta * e;
2369}
2370
2371static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
2372 auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
2373 ctx->mu = 2.0f*ctx->tau;
2374 ctx->seed_cur = get_rng_seed(ctx->seed);
2375 ctx->rng.seed(ctx->seed_cur);
2376}
2377
2378static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
2379 const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
2380
2381 auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
2382
2383 // copy the state
2384 {
2385 auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
2386
2387 result_ctx->mu = ctx->mu;
2388 result_ctx->rng = ctx->rng;
2389 }
2390
2391 return result;
2392}
2393
// Release the Mirostat 2.0 state owned by the sampler.
static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
    delete (llama_sampler_mirostat_v2 *) smpl->ctx;
}

// Interface table for the Mirostat 2.0 sampler (positional; must follow llama_sampler_i's field order).
// apply() both filters and selects a token; the sampler provides no accept or backend_* hooks.
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_v2_apply,
    /* .reset = */ llama_sampler_mirostat_v2_reset,
    /* .clone = */ llama_sampler_mirostat_v2_clone,
    /* .free = */ llama_sampler_mirostat_v2_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2410
2411struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
2412 auto seed_cur = get_rng_seed(seed);
2413 return llama_sampler_init(
2414 /* .iface = */ &llama_sampler_mirostat_v2_i,
2415 /* .ctx = */ new llama_sampler_mirostat_v2 {
2416 /* .seed = */ seed,
2417 /* .seed_cur = */ seed_cur,
2418 /* .tau = */ tau,
2419 /* .eta = */ eta,
2420 /* .mu = */ 2.0f*tau,
2421 /* .rng = */ std::mt19937(seed_cur),
2422 }
2423 );
2424}
2425
2426// grammar
2427
// State for the grammar-constrained sampler.
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;

    // grammar source text and root rule name, kept so reset() can rebuild the grammar from scratch
    std::string grammar_str;
    std::string grammar_root;

    // parsed grammar state (owned); nullptr when the sampler was created without a grammar,
    // in which case accept/apply do nothing
    struct llama_grammar * grammar;
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    return "grammar";
}
2440
2441static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
2442 auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2443 if (ctx->grammar) {
2444 llama_grammar_accept_impl(*ctx->grammar, token);
2445 }
2446}
2447
2448static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2449 auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2450 if (ctx->grammar) {
2451 llama_grammar_apply_impl(*ctx->grammar, cur_p);
2452 }
2453}
2454
// Forward declaration to break the definition cycle
// clone --> init_impl --> llama_sampler_grammar_i --> clone:
// llama_sampler_grammar_clone calls this constructor, whose definition refers to the
// interface table, which in turn refers to the clone function.
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
        const char * grammar_str,
        const char * grammar_root,
        bool lazy,
        const char ** trigger_words,
        size_t num_trigger_words,
        const llama_token * trigger_tokens,
        size_t num_trigger_tokens,
        const char ** trigger_patterns,
        size_t num_trigger_patterns);
2467
2468static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
2469 auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2470 if (!ctx->grammar) {
2471 return;
2472 }
2473
2474 std::vector<const char *> trigger_patterns_c;
2475 trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
2476 for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
2477 trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
2478 }
2479
2480 auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
2481 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
2482 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
2483
2484 llama_grammar_free_impl(ctx->grammar);
2485 ctx->grammar = grammar_new;
2486}
2487
2488static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
2489 const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
2490
2491 auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
2492 GGML_ASSERT(result);
2493
2494 // copy the state
2495 {
2496 auto * result_ctx = (llama_sampler_grammar *) result->ctx;
2497
2498 if (ctx->grammar) {
2499 result_ctx->grammar_str = ctx->grammar_str;
2500 result_ctx->grammar_root = ctx->grammar_root;
2501
2502 result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
2503 }
2504 }
2505
2506 return result;
2507}
2508
2509static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
2510 const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2511
2512 if (ctx->grammar) {
2513 llama_grammar_free_impl(ctx->grammar);
2514 }
2515
2516 delete ctx;
2517}
2518
// Interface table for the grammar sampler (positional; must follow llama_sampler_i's field order).
// Unlike most filters it has an accept hook, because the grammar state advances with each accepted token.
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply = */ llama_sampler_grammar_apply,
    /* .reset = */ llama_sampler_grammar_reset,
    /* .clone = */ llama_sampler_grammar_clone,
    /* .free = */ llama_sampler_grammar_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2531
// Shared constructor behind all llama_sampler_init_grammar* entry points.
//
// - A null or empty grammar_str yields a sampler with no grammar; its accept/apply hooks
//   then do nothing.
// - Legacy trigger_words (mutually exclusive with trigger_patterns, enforced by GGML_ASSERT)
//   are regex-escaped and OR-ed into one pattern of the form "[\s\S]*?(w1|w2|...)[\s\S]*".
// - If llama_grammar_init_impl fails (returns nullptr), the context is freed and nullptr
//   is returned.
// NOTE(review): when grammar_str is non-empty, grammar_root is converted to std::string, so
// it must not be nullptr in that case — confirm that all callers pass a root rule name.
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
        const char * grammar_str,
        const char * grammar_root,
        bool lazy,
        const char ** trigger_words,
        size_t num_trigger_words,
        const llama_token * trigger_tokens,
        size_t num_trigger_tokens,
        const char ** trigger_patterns,
        size_t num_trigger_patterns) {
    auto * ctx = new llama_sampler_grammar;

    if (grammar_str != nullptr && grammar_str[0] != '\0') {
        // must outlive the llama_grammar_init_impl call below, which receives its c_str()
        std::string trigger_pattern;
        llama_grammar * grammar = nullptr;
        // TODO: remove trigger_words support.
        if (trigger_words != nullptr && num_trigger_words > 0) {
            GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
            trigger_pattern = "[\\s\\S]*?(";
            for (size_t i = 0; i < num_trigger_words; ++i) {
                // regex metacharacters that must be escaped so each word matches literally
                static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
                if (i > 0) {
                    trigger_pattern += "|";
                }
                trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
            }
            trigger_pattern += ")[\\s\\S]*";

            std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
        } else {
            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
        }
        *ctx = {
            /* .vocab = */ vocab,
            /* .grammar_str = */ grammar_str,
            /* .grammar_root = */ grammar_root,
            /* .grammar = */ grammar,
        };
        // grammar construction failed: do not leak the context
        if (!ctx->grammar) {
            delete ctx;
            return nullptr;
        }
    } else {
        // no grammar: pass-through sampler (accept/apply check for a null grammar)
        *ctx = {
            /* .vocab = */ vocab,
            /* .grammar_str = */ {},
            /* .grammar_root = */ {},
            /* .grammar = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_grammar_i,
        /* .ctx = */ ctx
    );
}
2590
2591struct llama_sampler * llama_sampler_init_grammar(
2592 const struct llama_vocab * vocab,
2593 const char * grammar_str,
2594 const char * grammar_root) {
2595 return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
2596}
2597
2598struct llama_sampler * llama_sampler_init_grammar_lazy(
2599 const struct llama_vocab * vocab,
2600 const char * grammar_str,
2601 const char * grammar_root,
2602 const char ** trigger_words,
2603 size_t num_trigger_words,
2604 const llama_token * trigger_tokens,
2605 size_t num_trigger_tokens) {
2606 return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
2607}
2608
2609struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
2610 const struct llama_vocab * vocab,
2611 const char * grammar_str,
2612 const char * grammar_root,
2613 const char ** trigger_patterns,
2614 size_t num_trigger_patterns,
2615 const llama_token * trigger_tokens,
2616 size_t num_trigger_tokens) {
2617 return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
2618}
2619
2620// penalties
2621
// State for the repetition / frequency / presence penalties sampler.
// NOTE: brace-initialized positionally in llama_sampler_init_penalties; keep field order.
struct llama_sampler_penalties {
    const int32_t penalty_last_n;  // size of the recent-token window (0 disables the sampler)
    const float penalty_repeat;    // multiplicative repetition penalty (1.0 = off)
    const float penalty_freq;      // subtracted once per occurrence in the window (0 = off)
    const float penalty_present;   // subtracted once if the token occurs at all (0 = off)

    // the last penalty_last_n accepted tokens
    ring_buffer<llama_token> prev;

    // a frequency map to count token occurrences
    // (kept in sync with `prev`: entries are removed when their count drops to 0)
    std::unordered_map<llama_token, int> token_count;
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    return "penalties";
}
2637
2638static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
2639 auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2640 if (ctx->penalty_last_n == 0) {
2641 return;
2642 }
2643
2644 ctx->token_count[token]++;
2645
2646 // if the ring buffer is full, remove the oldest token
2647 if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
2648 const auto old = ctx->prev.front();
2649
2650 ctx->token_count[old]--;
2651 if (ctx->token_count[old] == 0) {
2652 ctx->token_count.erase(old);
2653 }
2654 }
2655
2656 ctx->prev.push_back(token);
2657
2658#if 0
2659 // sanity check
2660 std::unordered_map<llama_token, int> tmp;
2661 for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
2662 tmp[ctx->prev.rat(i)]++;
2663 }
2664
2665 assert(ctx->token_count == tmp);
2666#endif
2667}
2668
2669static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2670 auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2671
2672 if ((ctx->penalty_last_n == 0) ||
2673 (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
2674 return;
2675 }
2676
2677 // Apply frequency and presence penalties to the cur_p
2678 for (size_t i = 0; i < cur_p->size; ++i) {
2679 const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
2680 if (token_iter == ctx->token_count.end()) {
2681 continue;
2682 }
2683
2684 const int count = token_iter->second;
2685
2686 assert(count > 0 && count <= ctx->penalty_last_n);
2687
2688 // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
2689 // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
2690 if (cur_p->data[i].logit <= 0) {
2691 cur_p->data[i].logit *= ctx->penalty_repeat;
2692 } else {
2693 cur_p->data[i].logit /= ctx->penalty_repeat;
2694 }
2695
2696 cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
2697 }
2698
2699 cur_p->sorted = false;
2700}
2701
2702static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
2703 auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2704 ctx->prev.clear();
2705 ctx->token_count.clear();
2706}
2707
2708static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
2709 const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
2710 auto * result = llama_sampler_init_penalties(
2711 ctx->penalty_last_n,
2712 ctx->penalty_repeat,
2713 ctx->penalty_freq,
2714 ctx->penalty_present);
2715
2716 // copy the state
2717 {
2718 auto * result_ctx = (llama_sampler_penalties *) result->ctx;
2719
2720 result_ctx->prev = ctx->prev;
2721 }
2722
2723 return result;
2724}
2725
// Release the penalties state owned by the sampler.
static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
    delete (llama_sampler_penalties *) smpl->ctx;
}

// Interface table for the penalties sampler (positional; must follow llama_sampler_i's field order).
// accept() maintains the recent-token window that apply() consults.
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply = */ llama_sampler_penalties_apply,
    /* .reset = */ llama_sampler_penalties_reset,
    /* .clone = */ llama_sampler_penalties_clone,
    /* .free = */ llama_sampler_penalties_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2742
2743struct llama_sampler * llama_sampler_init_penalties(
2744 int32_t penalty_last_n,
2745 float penalty_repeat,
2746 float penalty_freq,
2747 float penalty_present) {
2748 penalty_last_n = std::max(penalty_last_n, 0);
2749
2750 const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
2751
2752 if (is_empty) {
2753 return llama_sampler_init_empty("?penalties");
2754 }
2755
2756 return llama_sampler_init(
2757 /* .iface = */ &llama_sampler_penalties_i,
2758 /* .ctx = */ new llama_sampler_penalties {
2759 /* .penalty_last_n = */ penalty_last_n,
2760 /* .penalty_repeat = */ penalty_repeat,
2761 /* .penalty_freq = */ penalty_freq,
2762 /* .penalty_present = */ penalty_present,
2763 /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
2764 /* .token_count = */ {},
2765 }
2766 );
2767}
2768
2769// top-n-sigma
2770
// State for the top-n-sigma sampler.
struct llama_sampler_top_n_sigma {
    const float n; // number of standard deviations below the max logit to keep; <= 0 disables
};

// Sampler name reported through the llama_sampler_i interface.
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    return "top-n-sigma";
}
2778
2779static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2780 auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
2781
2782 if (ctx->n <= 0.0f || cur_p->size <= 1) {
2783 return;
2784 }
2785
2786 // find max logit and calculate mean
2787 float max = cur_p->data[0].logit;
2788 float logits_sum = 0;
2789 size_t valid_count = 0;
2790 for (size_t i = 0; i < cur_p->size; ++i) {
2791 // Only count non-negative infinity values
2792 if (cur_p->data[i].logit != -INFINITY) {
2793 max = std::max(max, cur_p->data[i].logit);
2794 logits_sum += cur_p->data[i].logit;
2795 valid_count++;
2796 }
2797 }
2798 float mean = valid_count > 0 ? logits_sum/valid_count : 0;
2799
2800 // calculate standard deviation
2801 float acc = 0;
2802 for (size_t i = 0; i < cur_p->size; ++i) {
2803 // Skip -infinity in std calculation
2804 if (cur_p->data[i].logit != -INFINITY) {
2805 acc += pow(cur_p->data[i].logit - mean, 2);
2806 }
2807 }
2808 float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
2809
2810 // apply mask
2811 for (size_t i = 0; i < cur_p->size; ++i) {
2812 if (cur_p->data[i].logit < max - (ctx->n * std)) {
2813 cur_p->data[i].logit = -INFINITY;
2814 }
2815 }
2816
2817 llama_sampler_softmax_impl(cur_p, true);
2818}
2819
2820static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
2821 const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
2822 return llama_sampler_init_top_n_sigma(ctx->n);
2823}
2824
2825static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
2826 delete (llama_sampler_top_n_sigma *) smpl->ctx;
2827}
2828
// Interface table for the top-n-sigma sampler (positional; must follow llama_sampler_i's field order).
// Stateless apart from its parameter, so there are no accept/reset hooks; no backend_* hooks either.
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
    /* .name = */ llama_sampler_top_n_sigma_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_top_n_sigma_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_n_sigma_clone,
    /* .free = */ llama_sampler_top_n_sigma_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
2841
2842struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
2843 const bool is_empty = (n <= 0.0f);
2844
2845 if (is_empty) {
2846 return llama_sampler_init_empty("?top-n-sigma");
2847 }
2848
2849 return llama_sampler_init(
2850 /* .iface = */ &llama_sampler_top_n_sigma_i,
2851 /* .ctx = */ new llama_sampler_top_n_sigma {
2852 /* .n = */ n,
2853 }
2854 );
2855}
2856
2857// DRY
2858
// State for the DRY ("don't repeat yourself") sampler.
struct llama_sampler_dry {
    // used as the effective window when dry_penalty_last_n == -1, and as an upper bound otherwise
    int32_t total_context_size;

    const float dry_multiplier;     // 0 disables the sampler
    const float dry_base;           // values < 1 disable the sampler
    const int32_t dry_allowed_length;
    const int32_t dry_penalty_last_n; // 0 disables; -1 means "use total_context_size"

    // sequence breakers: head token -> tail token sequence (empty tail = single-token breaker)
    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    // per-call scratch buffers, re-initialized at the start of each apply()
    std::vector<int> dry_repeat_count;
    std::unordered_map<llama_token, int> dry_max_token_repeat;
    // history of accepted tokens
    ring_buffer<llama_token> last_tokens;
};
2872
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// Add the DRY sequence-breaker entries for one breaker string `str` to `token_sequences`
// (existing entries are kept). For every token in the vocabulary:
//   - if the token's detokenized text contains `str`, record (token -> empty tail): the
//     token alone completes the breaker;
//   - otherwise, for every position where the token's text ends with a proper prefix of
//     `str`, tokenize the remainder of `str` and record (token -> that tail). The tail is
//     truncated to `max_tail_len` tokens when max_tail_len >= 0, and tails already
//     recorded for this token are skipped.
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
        std::string word = vocab.detokenize({token_id}, true);
        if (word.find(str) != std::string::npos) {
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size();
            size_t str_len = str.size();
            // starts at SIZE_MAX so that the first search begins at pos + 1 == 0
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                // compare the rest of `str` against the word until either one runs out
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    // the word ends with str[0..i); the tail is whatever of `str` remains
                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(token_id, tokenization);
                    }
                }
            }
        }
    }
}
2915
2916static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
2917 return "dry";
2918}
2919
2920static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
2921 auto * ctx = (llama_sampler_dry *) smpl->ctx;
2922 if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
2923 return;
2924 }
2925
2926 ctx->last_tokens.push_back(token);
2927}
2928
2929// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
2930static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2931 auto * ctx = (llama_sampler_dry *) smpl->ctx;
2932
2933 if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
2934 return;
2935 }
2936
2937 int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
2938 int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
2939
2940 if (last_n_repeat <= ctx->dry_allowed_length) {
2941 return;
2942 }
2943
2944 ctx->dry_repeat_count.assign(last_n_repeat, 0);
2945 ctx->dry_max_token_repeat.clear();
2946
2947 // Step 1: Look for restart sequences to limit the maximum repetition length.
2948 // Work backwards through the context looking for any token that begins a restart sequence.
2949 //
2950 // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
2951 // sequences that together comprise a restart sequence. This allows us to quickly check
2952 // whether each token is the head of a complete sequence. Most restart sequences are actually
2953 // a single token, and for these the "tail" is an empty vector.
2954 //
2955 // If the token is a "head", test all restart sequences that begin with this token
2956 // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
2957 // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
2958 // longest matching sequence (if any) is used to limit the maximum repetition length.
2959 //
2960 // Note that in the case case of a short sequence contained in a longer one, this might fail to
2961 // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
2962 // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
2963 // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
2964 //
2965 // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
2966 // have already clamped the maximum tail sequence length when generating `restart_sequences`.
2967 // With clamping, this scan is O(N) in the context length.
2968
2969 int rep_limit = last_n_repeat;
2970 for (int i = 0; i < last_n_repeat; ++i) {
2971 llama_token token = ctx->last_tokens.rat(i);
2972 auto its = ctx->dry_processed_breakers.equal_range(token);
2973 if (its.first == ctx->dry_processed_breakers.end()) {
2974 continue;
2975 }
2976 int longest_match = -1;
2977 for (auto it = its.first; it != its.second; ++it) {
2978 // Note that (*it) does not contain the head character, so seq_len will be
2979 // the restart sequence length minus 1.
2980 // In the common case of a single-token restart sequence, (*it) will be empty
2981 // and we will trivially match.
2982 int seq_len = (int)it->second.size();
2983 if (seq_len > longest_match && seq_len <= (int)i) {
2984 bool match = true;
2985 for (int offset = 0; offset < seq_len; ++offset) {
2986 // The -1 when indexing `last_tokens` is because we already matched the head.
2987 if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
2988 match = false;
2989 break;
2990 }
2991 }
2992 if (match) {
2993 longest_match = seq_len;
2994 }
2995 }
2996 }
2997 if (longest_match >= 0) {
2998 // We found a restart sequence starting `i` tokens from the end and continuing for
2999 // `longest_match` tokens.
3000 rep_limit = i - longest_match;
3001 break;
3002 }
3003 }
3004 if (rep_limit < ctx->dry_allowed_length) {
3005 return;
3006 }
3007
3008 // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
3009 // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
3010 // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
3011 //
3012 // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
3013 // https://ivanyu.me/blog/2014/10/15/z-algorithm/
3014 //
3015 // The code below is adapted from the public domain implementation by the same author here:
3016 // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
3017 //
3018 // Example:
3019 // Last N tokens: a b c c b c y a b c
3020 // Repeat counts: 0 0 3 1 0 2 0 0 0 0
3021 // ^
3022 // This `3` means that the last three tokens of the context (a b c) also appear here.
3023 //
3024 // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
3025 // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
3026 // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
3027 // ensure that the inner while loops only examine each token in the context once as the outer
3028 // for loop iterates over the context.
3029
3030 {
3031 const int last = last_n_repeat - 1;
3032
3033 int rt = 0;
3034 int lt = 0;
3035
3036 for (int k = 1; k < last_n_repeat; ++k) {
3037 if (k > rt) {
3038 // If k is outside the current Z-box, do naive computation.
3039 int n = 0;
3040 while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
3041 ++n;
3042 }
3043 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
3044 if (n > 0) {
3045 lt = k;
3046 rt = k + n - 1;
3047 }
3048 } else {
3049 // If k is inside the current Z-box, consider two cases.
3050
3051 int p = k - lt; // Pair index.
3052 int right_part_len = rt - k + 1;
3053
3054 if (ctx->dry_repeat_count[last - p] < right_part_len) {
3055 int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
3056 ctx->dry_repeat_count[last - k] = n;
3057 } else {
3058 int i = rt + 1;
3059 while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
3060 i += 1;
3061 }
3062
3063 int n = std::min(i - k, rep_limit);
3064 ctx->dry_repeat_count[last - k] = n;
3065 lt = k;
3066 rt = i - 1;
3067 }
3068 }
3069 }
3070 }
3071
3072 // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
3073 // that would be generated by emitting each new token that would extend a sequence.
3074 //
3075 // Following the same example as above:
3076 // Last N tokens: a b c c b c y a b c
3077 // Repeat counts: 0 0 3 1 0 2 0 0 0 0
3078 //
3079 // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
3080 // c: 3 -> 4 (from `a b c` to `a b c c`)
3081 // b: 1 -> 2 (from `c` to `c b`)
3082 // y: 2 -> 3 (from `b c` to `b c y`)
3083
3084 for (int i = 0; i < last_n_repeat - 1; ++i) {
3085 int repeat_len = ctx->dry_repeat_count[i];
3086 if (repeat_len >= ctx->dry_allowed_length) {
3087 // This token ends a repeat, so the next token would continue one.
3088 // By convention, the value of `repeat_len` only includes the tokens currently
3089 // in the context, not the new token that would be added.
3090 llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
3091 // Track the maximum sequence ending in this token.
3092 const auto& it = ctx->dry_max_token_repeat.find(token);
3093 if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
3094 ctx->dry_max_token_repeat[token] = repeat_len;
3095 }
3096 }
3097 }
3098
3099 // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
3100
3101 // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
3102 // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
3103 const float FLOAT_MAX_LOG = 88.7228391f;
3104 int max_exponent = 0;
3105 if (ctx->dry_base > 1.000001f) {
3106 max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
3107 }
3108
3109 for (size_t i = 0; i < cur_p->size; ++i) {
3110 const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
3111 if (af_kvp != ctx->dry_max_token_repeat.end()) {
3112 // Check all sequence breakers starting with this token
3113 auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
3114 bool is_single_token_breaker = false;
3115
3116 for (auto it = range.first; it != range.second; ++it) {
3117 if (it->second.empty()) {
3118 is_single_token_breaker = true;
3119 break;
3120 }
3121 }
3122
3123 // Apply penalty only if it's not a single-token sequence breaker
3124 if (!is_single_token_breaker) {
3125 int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
3126 if (max_exponent > 0 && repeat_exp > max_exponent) {
3127 repeat_exp = max_exponent;
3128 }
3129 float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
3130 cur_p->data[i].logit -= penalty;
3131 }
3132 }
3133 }
3134
3135 cur_p->sorted = false;
3136}
3137
3138static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
3139 auto * ctx = (llama_sampler_dry *) smpl->ctx;
3140 ctx->last_tokens.clear();
3141 ctx->dry_repeat_count.clear();
3142 ctx->dry_max_token_repeat.clear();
3143}
3144
3145static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
3146 const auto * ctx = (llama_sampler_dry *) smpl->ctx;
3147
3148 llama_vocab dummy_vocab;
3149
3150 // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
3151 auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
3152
3153 // Copy the state, including the processed breakers
3154 {
3155 auto * result_ctx = (llama_sampler_dry *) result->ctx;
3156 result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
3157 result_ctx->dry_repeat_count = ctx->dry_repeat_count;
3158 result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
3159 result_ctx->last_tokens = ctx->last_tokens;
3160 }
3161
3162 return result;
3163}
3164
3165static void llama_sampler_dry_free(struct llama_sampler * smpl) {
3166 delete (llama_sampler_dry *) smpl->ctx;
3167}
3168
// Callback table of the DRY sampler. The initializer is positional, so the entries must stay in
// the declaration order of llama_sampler_i. DRY only has a CPU implementation: all backend
// callbacks are left null.
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply = */ llama_sampler_dry_apply,
    /* .reset = */ llama_sampler_dry_reset,
    /* .clone = */ llama_sampler_dry_clone,
    /* .free = */ llama_sampler_dry_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
3181
3182struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
3183 int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
3184 std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
3185 const int MAX_CHAR_LEN = 40;
3186 const int MAX_SEQ_LEN = 20;
3187
3188 const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
3189
3190 if (!dry_enabled) {
3191 return llama_sampler_init_empty("?dry");
3192 }
3193
3194 if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
3195 // Process sequence breakers
3196 for (size_t i = 0; i < num_breakers; ++i) {
3197 if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
3198 LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
3199 continue;
3200 }
3201
3202 std::string sequence_break(seq_breakers[i]);
3203 if (sequence_break.empty()) {
3204 LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
3205 continue;
3206 }
3207
3208 if (sequence_break.size() > MAX_CHAR_LEN) {
3209 LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
3210 sequence_break.resize(MAX_CHAR_LEN);
3211 }
3212
3213 get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
3214 }
3215 }
3216
3217 return llama_sampler_init(
3218 /* .iface = */ &llama_sampler_dry_i,
3219 /* .ctx = */ new llama_sampler_dry {
3220 /* .total_context_size = */ n_ctx_train,
3221 /* .dry_multiplier = */ dry_multiplier,
3222 /* .dry_base = */ dry_base,
3223 /* .dry_allowed_length = */ dry_allowed_length,
3224 /* .dry_penalty_last_n = */ dry_penalty_last_n,
3225 /* .dry_processed_breakers = */ std::move(processed_breakers),
3226 /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
3227 /* .dry_max_token_repeat = */ {},
3228 /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
3229 }
3230 );
3231}
3232
3233// wrapper for test-sampling.cpp
3234struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
3235 llama_vocab dummy_vocab;
3236 auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
3237 auto * ctx = (llama_sampler_dry *) result->ctx;
3238
3239 // Process the token-based sequence breakers
3240 ctx->dry_processed_breakers.clear();
3241 if (seq_breakers.empty()) {
3242 LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
3243 } else {
3244 for (const auto& breaker : seq_breakers) {
3245 if (breaker.empty()) {
3246 LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
3247 continue;
3248 }
3249 llama_token head_token = breaker[0];
3250 std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
3251 ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
3252 }
3253
3254 if (ctx->dry_processed_breakers.empty()) {
3255 LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
3256 }
3257 }
3258
3259 return result;
3260}
3261
// adaptive-p sampler state
//
// maintains an exponential moving average of the *ORIGINAL* probabilities
// of selected tokens, used to compute an adapted target at each sampling step.
//
// see llama.h for a full description of the sampler
//
// ref: https://github.com/ggml-org/llama.cpp/pull/17927
//
// NOTE: llama_sampler_init_adaptive_p initializes this struct positionally, so the field order
// below must not change without updating that initializer.
//
struct llama_sampler_adaptive_p {
    const float target; // target probability (0.0 - 1.0; negative = disabled)
    const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
    const uint32_t seed; // original RNG seed (as passed by the caller)
    uint32_t seed_cur; // actual RNG seed, derived from `seed` via get_rng_seed() at init/reset
    std::mt19937 rng; // RNG state
    float weighted_sum; // sum(p_i * decay^i)
    float total_weight; // sum(decay^i), converges to 1/(1-decay)
    std::vector<float> original_probs; // pre-transform probs, cached for EMA update
    llama_token pending_token_id; // token ID of selected token (LLAMA_TOKEN_NULL when none is pending)
    int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs (-1 when none is pending)
};
3283
// Constants of the adaptive-p logit transform used in llama_sampler_adaptive_p_apply:
//   logit = PEAK_LOGIT_VALUE - SHARPNESS * d*d / (1 + d),   d = |p - adapted_target| / DISTRIBUTION_WIDTH
static constexpr float DISTRIBUTION_WIDTH { 0.3f };                     // distance scale for |p - target|
static constexpr float PEAK_LOGIT_VALUE   { 5.0f };                     // logit of a token exactly at the target
static constexpr float SHARPNESS          { 10.0f };                    // steepness of the fall-off with distance
static constexpr float INV_WIDTH          { 1.0f / DISTRIBUTION_WIDTH }; // reciprocal, so the hot loop multiplies
3289
// Identifier reported for the adaptive-p sampler (independent of the sampler's state).
static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
    static const char name[] = "adaptive-p";
    return name;
}
3293
3294static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3295 auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3296
3297 llama_sampler_softmax_impl(cur_p, false);
3298
3299 if (ctx->target < 0.0f) {
3300 // at negative target values, adaptive-p is no-op
3301 // we simply sample from the existing distribution
3302 cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
3303 return;
3304 }
3305
3306 // store the original probabilities
3307 ctx->original_probs.resize(cur_p->size);
3308 for (size_t i = 0; i < cur_p->size; ++i) {
3309 ctx->original_probs[i] = cur_p->data[i].p;
3310 }
3311
3312 // using the EMA, compute the adapted target probability for the current sampling step
3313 auto target = std::clamp(ctx->target, 0.0f, 1.0f);
3314 float adapted_target = std::clamp(
3315 ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
3316 0.0f, 1.0f
3317 );
3318
3319 // adaptive probability transform
3320 //
3321 // quadratic near target for fine differentiation, transitioning to linear decay in the
3322 // tails. unbounded negative logits ensure proper suppression of far-from-target tokens
3323 // after the softmax.
3324 //
3325 for (size_t i = 0; i < cur_p->size; ++i) {
3326 if (cur_p->data[i].logit == -INFINITY) {
3327 // don't transform logits that are -INFINITY
3328 // (as masked out by e.g. min-p and top-p when using backend sampling)
3329 continue;
3330 }
3331 float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
3332 cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
3333 }
3334
3335 // softmax and sample from the transformed distribution
3336 llama_sampler_softmax_impl(cur_p, false);
3337 const int idx = llama_sample_dist(cur_p, ctx->rng);
3338 cur_p->selected = idx;
3339
3340 // store the selected token ID for acceptance later
3341 ctx->pending_token_id = cur_p->data[idx].id;
3342 ctx->pending_token_idx = idx;
3343}
3344
3345static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) {
3346 auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3347 if (ctx->pending_token_id == token) {
3348 GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL);
3349 GGML_ASSERT(ctx->pending_token_idx != -1);
3350 // update EMA with the original probability of the selected token
3351 ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum;
3352 ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
3353 }
3354 ctx->pending_token_id = LLAMA_TOKEN_NULL;
3355 ctx->pending_token_idx = -1;
3356}
3357
3358static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
3359 auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3360 // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
3361 // original_probs is completely overwritten on every call to _apply.
3362 // so we only need to reset the EMA state and pending token.
3363 ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
3364 ctx->total_weight = 1.0f / (1.0f - ctx->decay);
3365 ctx->pending_token_id = LLAMA_TOKEN_NULL;
3366 ctx->pending_token_idx = -1;
3367 ctx->seed_cur = get_rng_seed(ctx->seed);
3368 ctx->rng.seed(ctx->seed_cur);
3369}
3370
3371static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
3372 const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;
3373 auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
3374 auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
3375
3376 // copy everything (target, decay, seed, and RNG are already set)
3377 result_ctx->weighted_sum = ctx->weighted_sum;
3378 result_ctx->total_weight = ctx->total_weight;
3379 result_ctx->pending_token_id = ctx->pending_token_id;
3380 result_ctx->pending_token_idx = ctx->pending_token_idx;
3381
3382 return result;
3383}
3384
3385static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
3386 delete (llama_sampler_adaptive_p *) smpl->ctx;
3387}
3388
// Callback table of the adaptive-p sampler. The initializer is positional, so the entries must
// stay in the declaration order of llama_sampler_i. The sampler runs on the CPU only: all
// backend callbacks are left null.
static struct llama_sampler_i llama_sampler_adaptive_p_i = {
    /* .name = */ llama_sampler_adaptive_p_name,
    /* .accept = */ llama_sampler_adaptive_p_accept,
    /* .apply = */ llama_sampler_adaptive_p_apply,
    /* .reset = */ llama_sampler_adaptive_p_reset,
    /* .clone = */ llama_sampler_adaptive_p_clone,
    /* .free = */ llama_sampler_adaptive_p_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
3401
3402struct llama_sampler * llama_sampler_init_adaptive_p(
3403 float target,
3404 float decay,
3405 uint32_t seed
3406) {
3407 auto seed_cur = get_rng_seed(seed);
3408 float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
3409 return llama_sampler_init(
3410 /* .iface = */ &llama_sampler_adaptive_p_i,
3411 /* .ctx = */ new llama_sampler_adaptive_p {
3412 /* .target = */ target,
3413 /* .decay = */ clamped_decay,
3414 /* .seed = */ seed,
3415 /* .seed_cur = */ seed_cur,
3416 /* .rng = */ std::mt19937(seed_cur),
3417 /* .weighted_sum = */ target / (1.0f - clamped_decay),
3418 /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
3419 /* .original_probs = */ {},
3420 /* .pending_token_id = */ LLAMA_TOKEN_NULL,
3421 /* .pending_token_idx = */ -1
3422 }
3423 );
3424}
3425
3426// logit-bias
3427
// State of the logit-bias sampler. get_name() and init() used by the callbacks below are not
// declared here, so they come from the llama_sampler_backend base. Initialized positionally in
// llama_sampler_init_logit_bias: keep the field order in sync with that initializer.
struct llama_sampler_logit_bias : public llama_sampler_backend {
    const int32_t n_vocab; // vocab size; backend_set_input asserts every biased token is below it

    const std::vector<llama_logit_bias> logit_bias; // (token, bias) entries, fixed at construction

    std::vector<llama_logit_bias> to_search; // scratch for apply(): entries not resolvable by direct indexing

    struct ggml_tensor * inp_logit_bias; // backend input with the bias values, created in backend_apply
    struct ggml_tensor * inp_logit_idxs; // backend input with the token ids receiving each bias, created in backend_apply
};
3438
3439static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
3440 auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3441 return ctx->get_name();
3442}
3443
3444static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3445 auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3446
3447 if (ctx->logit_bias.empty()) {
3448 return;
3449 }
3450
3451 ctx->to_search.clear();
3452
3453 // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
3454 for (const auto & lb : ctx->logit_bias) {
3455 if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
3456 cur_p->data[lb.token].logit += lb.bias;
3457 } else {
3458 ctx->to_search.push_back(lb);
3459 }
3460 }
3461
3462 if (ctx->to_search.empty()) {
3463 return;
3464 }
3465
3466 // search for the remaining candidates that were not found in the previous step
3467 for (size_t i = 0; i < cur_p->size; ++i) {
3468 for (const auto & lb : ctx->to_search) {
3469 if (cur_p->data[i].id == lb.token) {
3470 cur_p->data[i].logit += lb.bias;
3471 break;
3472 }
3473 }
3474 }
3475}
3476
3477static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
3478 const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
3479 return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
3480}
3481
3482static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
3483 delete (llama_sampler_logit_bias *) smpl->ctx;
3484}
3485
3486static void llama_sampler_logit_bias_backend_apply(
3487 struct llama_sampler * smpl,
3488 struct ggml_context * ctx,
3489 struct ggml_cgraph * gf,
3490 struct llama_sampler_data * data) {
3491 GGML_UNUSED(gf);
3492 GGML_UNUSED(ctx);
3493
3494 auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3495 if (sctx->logit_bias.empty()) {
3496 return;
3497 }
3498
3499 const size_t n = sctx->logit_bias.size();
3500
3501 sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
3502 ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3503 ggml_set_input(sctx->inp_logit_bias);
3504
3505 sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
3506 ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3507 ggml_set_input(sctx->inp_logit_idxs);
3508
3509 ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
3510
3511 cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
3512 cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
3513 cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
3514
3515 data->logits = ggml_add(ctx, data->logits, cur);
3516}
3517
3518static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
3519 auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3520 if (sctx->logit_bias.empty()) {
3521 return;
3522 }
3523
3524 GGML_ASSERT(sctx->inp_logit_bias != nullptr);
3525 GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
3526
3527 const size_t n = sctx->logit_bias.size();
3528
3529 std::vector<float> data_logit_bias(n, 0.0f);
3530 std::vector<int32_t> data_logit_idxs(n, 0);
3531 for (size_t i = 0; i < n; ++i) {
3532 const auto & lb = sctx->logit_bias[i];
3533 GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
3534 data_logit_bias[i] = lb.bias;
3535 data_logit_idxs[i] = lb.token;
3536 }
3537
3538 ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
3539 ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
3540}
3541
3542static bool llama_sampler_logit_bias_backend_init(
3543 struct llama_sampler * smpl,
3544 ggml_backend_buffer_type_t buft) {
3545 GGML_UNUSED(buft);
3546
3547 auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3548
3549 sctx->init(true);
3550
3551 if (sctx->logit_bias.empty()) {
3552 return true;
3553 }
3554
3555 return true;
3556}
3557
// Callback table of the logit-bias sampler. The initializer is positional, so the entries must
// stay in the declaration order of llama_sampler_i. The sampler keeps no per-token state
// (no accept/reset) and provides both a CPU path (apply) and a backend path
// (backend_init / backend_apply / backend_set_input).
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_logit_bias_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_logit_bias_clone,
    /* .free = */ llama_sampler_logit_bias_free,
    /* .backend_init = */ llama_sampler_logit_bias_backend_init,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
    /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
};
3570
3571struct llama_sampler * llama_sampler_init_logit_bias(
3572 int32_t n_vocab,
3573 int32_t n_logit_bias,
3574 const llama_logit_bias * logit_bias) {
3575 const bool is_empty = n_logit_bias <= 0;
3576
3577 if (is_empty) {
3578 return llama_sampler_init_empty("?logit-bias");
3579 }
3580
3581 return llama_sampler_init(
3582 /* .iface = */ &llama_sampler_logit_bias_i,
3583 /* .ctx = */ new llama_sampler_logit_bias {
3584 ("logit-bias"),
3585 /* .n_vocab = */ n_vocab,
3586 /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
3587 /* .to_search = */ {},
3588 /* .inp_logit_bias = */ nullptr,
3589 /* .inp_logit_idxs = */ nullptr,
3590 }
3591 );
3592}
3593
3594// infill
3595
3596//#define GGML_DEBUG_SAMPLER_INFILL
3597
// State of the infill sampler: the vocab (used for EOG checks and token-to-text conversion)
// and two scratch buffers that hold the text of the two tokens compared by
// llama_sampler_infill_apply when merging tokens with a common prefix.
struct llama_sampler_infill {
    const struct llama_vocab * vocab;

    std::vector<char> buf0; // text of token i0 of the pair being compared
    std::vector<char> buf1; // text of token i1 of the pair being compared
};
3604
// Identifier reported for the infill sampler (independent of the sampler's state).
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    static const char name[] = "infill";
    return name;
}
3608
3609static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3610 auto * ctx = (llama_sampler_infill *) smpl->ctx;
3611
3612 llama_sampler_softmax_impl(cur_p, true);
3613
3614#if defined(GGML_DEBUG_SAMPLER_INFILL)
3615#define LOG_DBG_CUR LLAMA_LOG_DEBUG
3616#else
3617#define LOG_DBG_CUR(...)
3618#endif
3619
3620 for (size_t i = 0; i < cur_p->size; ++i) {
3621 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3622 }
3623
3624 float p_txt_sum = 0.0f;
3625 float p_eog_sum = 0.0f;
3626
3627 for (size_t i = 0; i < cur_p->size; ++i) {
3628 if (ctx->vocab->is_eog(cur_p->data[i].id)) {
3629 p_eog_sum += cur_p->data[i].p;
3630 } else {
3631 p_txt_sum += cur_p->data[i].p;
3632 }
3633 }
3634
3635 const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
3636
3637 LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
3638
3639 if (3*p_eog_sum*cur_p->size > p_txt_sum) {
3640 LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
3641
3642 // keep just the EOG tokens
3643 const auto size_org = cur_p->size;
3644
3645 cur_p->size = 0;
3646
3647 float p_sum = 0.0f;
3648
3649 for (size_t i = 0; i < size_org; ++i) {
3650 if (ctx->vocab->is_eog(cur_p->data[i].id)) {
3651 p_sum += cur_p->data[i].p;
3652
3653 cur_p->data[cur_p->size++] = cur_p->data[i];
3654 }
3655 }
3656
3657 // normalize probs
3658 for (size_t i = 0; i < cur_p->size; ++i) {
3659 cur_p->data[i].p /= p_sum;
3660 }
3661
3662 return;
3663 }
3664
3665 size_t n_combined = 0; GGML_UNUSED(n_combined);
3666
3667 // combine tokens with common prefix
3668 for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
3669 for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
3670 if (cur_p->data[i0].logit == -INFINITY) {
3671 break;
3672 }
3673
3674 if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
3675 continue;
3676 }
3677
3678 int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
3679 if (len0 < 0) {
3680 ctx->buf0.resize(len0);
3681 len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
3682 assert(len0 > 0);
3683 }
3684
3685 int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
3686 if (len1 < 0) {
3687 ctx->buf1.resize(len1);
3688 len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
3689 assert(len1 > 0);
3690 }
3691
3692 // token i0 is a prefix of token i1
3693 if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
3694 int dst = i0;
3695 int src = i1;
3696
3697 // merge into the token with higher probability
3698 if (cur_p->data[i1].p > cur_p->data[i0].p) {
3699 std::swap(dst, src);
3700 }
3701
3702 cur_p->data[dst].p += cur_p->data[src].p;
3703 cur_p->data[src].logit = -INFINITY;
3704 cur_p->data[src].p = 0.0f;
3705
3706 n_combined++;
3707 }
3708 }
3709 }
3710
3711 size_t n_non_eog = 0;
3712
3713 size_t size_org = cur_p->size;
3714
3715 float p_sum = 0.0f;
3716 float thold = 0.2f;
3717
3718 cur_p->size = 0;
3719
3720 LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
3721
3722 for (size_t i = 0; i < size_org; ++i) {
3723 const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
3724
3725 if (cur_p->data[i].p < thold && !is_eog) {
3726 continue;
3727 }
3728
3729 if (!is_eog) {
3730 ++n_non_eog;
3731 }
3732
3733 p_sum += cur_p->data[i].p;
3734
3735 // keep this token
3736 cur_p->data[cur_p->size++] = cur_p->data[i];
3737 }
3738
3739 LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
3740
3741 // if no non-EOG tokens are left -> reduce cur_p to single EOT token
3742 if (n_non_eog == 0) {
3743 cur_p->size = 1;
3744 cur_p->data[0].id = ctx->vocab->token_eot();
3745 if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
3746 cur_p->data[0].id = ctx->vocab->token_eos();
3747 }
3748 cur_p->data[0].logit = 1.0f;
3749
3750 GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
3751
3752 return;
3753 }
3754
3755 // normalize probs
3756 for (size_t i = 0; i < cur_p->size; ++i) {
3757 cur_p->data[i].p /= p_sum;
3758
3759 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3760 }
3761
3762 size_org = cur_p->size;
3763 p_sum = 0.0f;
3764 thold = 1.0/(n_non_eog + 1);
3765
3766 cur_p->size = 0;
3767
3768 LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
3769
3770 for (size_t i = 0; i < size_org; ++i) {
3771 const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
3772
3773 if (cur_p->data[i].p < thold && !is_eog) {
3774 continue;
3775 }
3776
3777 p_sum += cur_p->data[i].p;
3778
3779 cur_p->data[cur_p->size++] = cur_p->data[i];
3780 }
3781
3782 // normalize probs
3783 for (size_t i = 0; i < cur_p->size; ++i) {
3784 cur_p->data[i].p /= p_sum;
3785
3786 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3787 }
3788
3789#undef LOG_DBG_CUR
3790}
3791
3792static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
3793 const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
3794 return llama_sampler_init_infill(ctx->vocab);
3795}
3796
3797static void llama_sampler_infill_free(struct llama_sampler * smpl) {
3798 delete (llama_sampler_infill *) smpl->ctx;
3799}
3800
3801static struct llama_sampler_i llama_sampler_infill_i = {
3802 /* .name = */ llama_sampler_infill_name,
3803 /* .accept = */ nullptr,
3804 /* .apply = */ llama_sampler_infill_apply,
3805 /* .reset = */ nullptr,
3806 /* .clone = */ llama_sampler_infill_clone,
3807 /* .free = */ llama_sampler_infill_free,
3808 /* .backend_apply = */ nullptr,
3809 /* .backend_accept = */ nullptr,
3810 /* .backend_set_input = */ nullptr,
3811 /* .backend_init = */ nullptr,
3812};
3813
3814struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
3815 return llama_sampler_init(
3816 /* .iface = */ &llama_sampler_infill_i,
3817 /* .ctx = */ new llama_sampler_infill {
3818 /* .vocab = */ vocab,
3819 /* .buf0 = */ std::vector<char>(512),
3820 /* .buf1 = */ std::vector<char>(512),
3821 }
3822 );
3823}
3824
3825// utils
3826
3827uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
3828 if (smpl->iface == &llama_sampler_dist_i) {
3829 return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
3830 }
3831
3832 if (smpl->iface == &llama_sampler_mirostat_i) {
3833 return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
3834 }
3835
3836 if (smpl->iface == &llama_sampler_mirostat_v2_i) {
3837 return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
3838 }
3839
3840 if (smpl->iface == &llama_sampler_chain_i) {
3841 const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
3842 for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
3843 const uint32_t seed = llama_sampler_get_seed(it->ptr);
3844 if (seed != LLAMA_DEFAULT_SEED) {
3845 return seed;
3846 }
3847 }
3848 }
3849
3850 return LLAMA_DEFAULT_SEED;
3851}
3852
3853// perf
3854
3855struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
3856 struct llama_perf_sampler_data data = {};
3857
3858 if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
3859 GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
3860 }
3861
3862 const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
3863
3864 data.t_sample_ms = 1e-3 * ctx->t_sample_us;
3865 data.n_sample = std::max(0, ctx->n_sample);
3866
3867 return data;
3868}
3869
3870void llama_perf_sampler_print(const struct llama_sampler * chain) {
3871 const auto data = llama_perf_sampler(chain);
3872
3873 LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
3874}
3875
3876void llama_perf_sampler_reset(struct llama_sampler * chain) {
3877 if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
3878 GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
3879 }
3880
3881 auto * ctx = (struct llama_sampler_chain *) chain->ctx;
3882
3883 ctx->t_sample_us = 0;
3884 ctx->n_sample = 0;
3885}