#include "ggml.h"
#include "llama.h"
#include "llama-cpp.h"
#include "get-model.h"
#include "common.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

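// Command-line options for this test: the model to load, an optional single
// test to run, and the device to run on.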
struct test_args {
    std::string model;
    std::string test;
    std::string device = "auto";
};

struct test_params {
    llama_model_ptr model;
};

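// Load the model given in args, optionally pinning it to the CPU or GPU backend
// device. Returns nullptr (after printing a message) when the model cannot be
// loaded or the requested device is unavailable, so the caller can skip the tests.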
static llama_model_ptr load_model(const test_args & args) {
    auto mparams = llama_model_default_params();

    ggml_backend_dev_t devs[2] = { nullptr, nullptr };

    if (args.device != "auto") {
        if (args.device == "gpu") {
            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);

            if (devs[0] == nullptr) {
                fprintf(stderr, "Error: GPU requested but not available\n");
                return nullptr;
            }

            mparams.n_gpu_layers = 999;
        } else if (args.device == "cpu") {
            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);

            mparams.n_gpu_layers = 0;
        } else {
            fprintf(stderr, "Error: invalid device '%s'\n", args.device.c_str());
            return nullptr;
        }

        mparams.devices = devs;

        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
    }

    llama_model_ptr res;

    res.reset(llama_model_load_from_file(args.model.c_str(), mparams));

    if (!res) {
        fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model.c_str());
        return nullptr;
    }

    return res;
}

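// Small helper that owns a llama_context configured with per-sequence backend
// samplers and keeps track of the current position of every sequence across
// decode calls.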
struct test_context {
    llama_context_ptr ctx;

    int n_vocab = 0;

    const llama_vocab * vocab = nullptr;

    std::unordered_map<llama_seq_id, int32_t> seq_positions;
    std::unordered_map<llama_seq_id, int32_t> last_batch_info;

    test_context(const test_params & params, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
        auto * model = params.model.get();

        GGML_ASSERT(model);
        GGML_ASSERT(!ctx);

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 512;
        cparams.n_batch = 512;
        cparams.samplers = configs.data();
        cparams.n_samplers = configs.size();

        // If n_seq_max is not specified, calculate it from configs
        if (n_seq_max < 0) {
            int32_t max_seq_id = 0;
            for (const auto & config : configs) {
                max_seq_id = std::max(config.seq_id, max_seq_id);
            }
            cparams.n_seq_max = max_seq_id + 1;
        } else {
            cparams.n_seq_max = n_seq_max;
        }

        ctx.reset(llama_init_from_model(model, cparams));
        if (!ctx) {
            throw std::runtime_error("failed to create context");
        }

        llama_set_warmup(ctx.get(), false);

        vocab = llama_model_get_vocab(model);
        n_vocab = llama_vocab_n_tokens(vocab);
    }

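    // Tokenize and decode one prompt per sequence in a single batch. Only the
    // last token of each prompt requests an output, so each sequence produces
    // exactly one set of sampled results.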
    bool decode(const std::map<llama_seq_id, std::string> & prompts) {
        GGML_ASSERT(ctx);

        last_batch_info.clear();
        llama_batch batch = llama_batch_init(512, 0, prompts.size());

        for (const auto & [seq_id, prompt] : prompts) {
            std::vector<llama_token> tokens;
            tokens.push_back(llama_vocab_bos(vocab));

            std::vector<llama_token> prompt_tokens(32);
            int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                          prompt_tokens.data(), prompt_tokens.size(),
                                          false, false);
            if (n_tokens < 0) {
                fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                llama_batch_free(batch);
                return false;
            }

            for (int i = 0; i < n_tokens; i++) {
                tokens.push_back(prompt_tokens[i]);
            }

            if (seq_positions.find(seq_id) == seq_positions.end()) {
                seq_positions[seq_id] = 0;
            }

            int32_t start_pos = seq_positions[seq_id];
            for (size_t i = 0; i < tokens.size(); i++) {
                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
            }

            seq_positions[seq_id] = start_pos + tokens.size();
        }

        printf("Batch contents:\n");
        printf("n_tokens: %d\n", batch.n_tokens);
        for (int i = 0; i < batch.n_tokens; i++) {
            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);

            for (int j = 0; j < batch.n_seq_id[i]; j++) {
                printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
            }
            printf("], logits=%d\n", batch.logits[i]);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed\n");
            llama_batch_free(batch);
            return false;
        }

        // Build mapping from seq id to batch token idx
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id seq_id = batch.seq_id[i][0];
                last_batch_info[seq_id] = i;
            }
        }

        llama_batch_free(batch);
        return true;
    }

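    // Return the output index in the last decoded batch for the given sequence,
    // or -1 if the sequence produced no output.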
    int32_t idx_for_seq(llama_seq_id seq_id) {
        auto it = last_batch_info.find(seq_id);
        if (it == last_batch_info.end()) {
            fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
            return -1;
        }
        return it->second;
    }

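    // Rebuild the seq_id -> output index mapping from the batch that was just decoded.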
    void update_batch_info(const llama_batch & batch) {
        last_batch_info.clear();
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id cur_seq = batch.seq_id[i][0];
                last_batch_info[cur_seq] = i;
            }
        }
    }

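    // Decode a single token for one sequence and advance its position.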
    bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
        GGML_ASSERT(ctx);

        llama_batch batch = llama_batch_init(1, 0, 1);
        int32_t pos = seq_positions[seq_id];
        common_batch_add(batch, token, pos, { seq_id }, true);

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
            llama_batch_free(batch);
            return false;
        }

        update_batch_info(batch);

        seq_positions[seq_id]++;
        llama_batch_free(batch);

        return true;
    }

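    // Decode one token per sequence in a single batch (one generation step for
    // all sequences at once).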
    bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
        GGML_ASSERT(ctx);

        llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());

        for (const auto & [seq_id, token] : seq_tokens) {
            int32_t pos = seq_positions[seq_id];
            common_batch_add(batch, token, pos, { seq_id }, true);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
            llama_batch_free(batch);
            return false;
        }

        for (const auto & [seq_id, _] : seq_tokens) {
            seq_positions[seq_id]++;
        }

        update_batch_info(batch);

        llama_batch_free(batch);

        return true;
    }

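    // Convert a token id to its text piece, growing the buffer if the first
    // attempt did not fit.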
    std::string token_to_piece(llama_token token, bool special) const {
        std::string piece;
        piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
        if (n_chars < 0) {
            piece.resize(-n_chars);
            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
            GGML_ASSERT(check == -n_chars);
        } else {
            piece.resize(n_chars);
        }

        return piece;
    }
};

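// Attach a greedy backend sampler to a single sequence and check that the
// sampled token ids are valid, both for an explicit output index and for the
// -1 (last output) shorthand, then run a few generation steps.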
static void test_backend_greedy_sampling(const test_params & params) {
    const int seq_id = 0;

    struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));

    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
        printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
}

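// Attach a top-k backend sampler and inspect the filtered logits/candidates,
// then finish sampling on the CPU (hybrid backend + CPU sampling).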
static void test_backend_top_k_sampling(const test_params & params) {
    const int seq_id = 0;
    const int32_t k = 8;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_logits; ++i) {
        printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
    }

    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_candidates; ++i) {
        printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
               test_ctx.token_to_piece(candidates[i], false).c_str());
    }

    // Sample with a CPU sampler to verify that hybrid sampling is possible:
    // first top-k on the backend, then dist on the CPU.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    GGML_ASSERT(chain->iface->backend_apply != nullptr);

    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend top-k hybrid sampling test PASSED\n");
}

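// Attach temperature backend samplers to two sequences and sample from each with
// a CPU dist sampler; also check that non-positive temperatures collapse the
// output to a single (argmax) logit.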
static void test_backend_temp_sampling(const test_params & params) {
    {
        const float temp_0 = 0.8f;
        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
        llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));

        const float temp_1 = 0.1f;
        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
        llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { 0, backend_sampler_chain_0.get() },
            { 1, backend_sampler_chain_1.get() }
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(0);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);

            // Sample from sequence 0 using CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }

        // Verify sequence 1
        {
            int32_t batch_idx = test_ctx.idx_for_seq(1);

            // Sample from sequence 1 using CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }
    }

    // lambda for testing non-positive temperature values
    auto test_argmax_temp = [&](float temp) {
        printf("\nTesting temperature = %.1f\n", temp);

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == 1);
    };

    test_argmax_temp(0.0f);
    test_argmax_temp(-1.0f);

    printf("backend temp sampling test PASSED\n");
}

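// Same as the temperature test, but for the extended temperature sampler and
// its temp/delta/exponent parameters.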
static void test_backend_temp_ext_sampling(const test_params & params) {
    {
        int seq_id = 0;
        const float temp = 0.8f;
        const float delta = 0.5f;
        const float exponent = 1.5f;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);
        }
    }

    // lambda for testing non-positive temp/delta/exponent values
    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

        if (temp <= 0.0f && delta >= 0.0f) {
            GGML_ASSERT(n_logits == 1);
        } else {
            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
        }
    };

    test_argmax_temp(0.0f, 0.3f, 1.0f);  // Greedy (temp=0)
    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
    test_argmax_temp(0.8f, 0.0f, 2.0f);  // Temperature scaling

    printf("backend temp_ext sampling test PASSED\n");
}

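// Attach a min-p backend sampler, check that it filters out part of the
// vocabulary, and sample the remaining candidates with a CPU dist sampler.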
static void test_backend_min_p_sampling(const test_params & params) {
    const int seq_id = 0;
    const float p = 0.1f;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that survived the min-p filter
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

    // Sample with a CPU sampler to verify that the remaining candidates are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("min-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    printf("min-p sampling test PASSED\n");
}

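// Attach a top-p backend sampler, check that it filters out part of the
// vocabulary, and sample the remaining candidates with a CPU dist sampler.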
static void test_backend_top_p_sampling(const test_params & params) {
    const int seq_id = 0;
    const float p = 0.9f;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that survived the top-p filter
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
    GGML_ASSERT(filtered_logits.size() > 0);

    // Sample with a CPU sampler to verify that the remaining candidates are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("top-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        test_ctx.decode_token(token, 0);
    }

    printf("top-p sampling test PASSED\n");
}

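// Run two sequences with different backend sampler chains in the same batch and
// verify that each sequence gets its own sampled token.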
static void test_backend_multi_sequence_sampling(const test_params & params) {
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());

    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    test_context test_ctx(params, backend_sampler_configs);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Generate tokens for each sequence
    printf("\nMulti-sequence generation:\n");
    for (int step = 0; step < 4; step++) {
        std::map<llama_seq_id, llama_token> tokens;

        for (llama_seq_id seq_id : {0, 1}) {
            int32_t idx = test_ctx.idx_for_seq(seq_id);
            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
            tokens[seq_id] = token;
        }

        // Decode all tokens in a single batch
        if (!test_ctx.decode_tokens(tokens)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    printf("backend multi-sequence sampling test PASSED\n");
}

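// Attach a dist backend sampler (using a non-zero seq_id) and verify that a
// token is sampled directly on the backend.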
static void test_backend_dist_sampling(const test_params & params) {
    const int seq_id = 189;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend dist sampling test PASSED\n");
}

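// Verify that a CPU sampler applied on top of a backend dist sampler returns the
// token that was already sampled on the backend.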
static void test_backend_dist_sampling_and_cpu(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
    GGML_ASSERT(backend_token == cpu_token);

    printf("backend dist & cpu sampling test PASSED\n");
}

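// Bias a specific token on the backend and verify that it is the one that gets sampled.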
static void test_backend_logit_bias_sampling(const test_params & params) {
    const auto * model = params.model.get();
    const auto * vocab = llama_model_get_vocab(model);

    const int seq_id = 0;

    std::vector<llama_logit_bias> logit_bias;

    // Get the token for the piece "World".
    const std::string piece = "World";
    std::vector<llama_token> tokens(16);
    llama_tokenize(vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);

    llama_token bias_token = tokens[0];
    // TODO: biasing too much here makes the Vulkan sampling fail - should be investigated further
    // https://github.com/ggml-org/llama.cpp/actions/runs/20894267644/job/60030252675?pr=18753#step:3:23350
    //logit_bias.push_back({ bias_token, +100.0f });
    logit_bias.push_back({ bias_token, +10.0f });

    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
        llama_vocab_n_tokens(vocab),
        logit_bias.size(),
        logit_bias.data()));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { seq_id, backend_sampler_chain.get() },
    };

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    printf("sampled token = %d, expected = %d\n", backend_token, bias_token);
    GGML_ASSERT(backend_token == bias_token);

    printf("backend logit bias sampling test PASSED\n");
}

// This test verifies that it is possible to have two different backend samplers:
// one that uses the backend dist sampler, and another that uses the CPU dist sampler.
static void test_backend_mixed_sampling(const test_params & params) {
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    int k = 40;
    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    test_context test_ctx(params, backend_sampler_configs);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0, which used the dist backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
    }

    // Verify sequence 1, which used the top-k backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(logits != nullptr);
        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == (size_t) k);
        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
    }

    printf("backend mixed sampling test PASSED\n");
}

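// Verify that llama_set_sampler can clear and later replace the backend sampler
// of a sequence at runtime.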
static void test_backend_set_sampler(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using the backend sampler configured above
    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());

    // Now clear the backend sampler for this sequence.
    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
    printf("Cleared backend sampler for seq_id %d\n", seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token }, };
    if (!test_ctx.decode_tokens(tokens)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Should not have any sampled token or probs after clearing the backend sampler.
    const int32_t idx = test_ctx.idx_for_seq(seq_id);
    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);

    // Sample the token using the CPU sampler chain.
    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
    const std::string token2_str = test_ctx.token_to_piece(token2, false);
    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
    std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2 }, };

    // Set a new backend sampler for the sequence.
    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());

    if (!test_ctx.decode_tokens(tokens2)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
    printf("top-k + dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());

    printf("backend set sampler test PASSED\n");
}

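// Mix backend-sampled and CPU-sampled sequences in the same batch, then clear
// and re-add the backend sampler for sequence 0.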
static void test_backend_cpu_mixed_batch(const test_params & params) {
    // Sequence 0 uses backend sampling
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
    };

    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
    test_context test_ctx(params, backend_sampler_configs, 2);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"}, // Will use backend sampling
        {1, "Some"}   // Will use CPU sampling
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 (backend sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1 (CPU sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);

        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);

        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Clear/remove the backend sampler, and sample again
    {
        // Clear the backend sampler for seq 0 so that there are no backend
        // samplers left.
        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);

        // Create a CPU sampler and verify we can sample from it.
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        if (!test_ctx.decode_token(token, 1)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    // Set a backend sampler again so that we can verify that it can be reset
    {
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));

        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());

        if (!test_ctx.decode_token(3834, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    printf("backend-cpu mixed batch test PASSED\n");
}

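// Mark every token of a prompt as an output and verify that llama_decode fails,
// since backend sampling expects at most one output per sequence.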
static void test_backend_max_outputs(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    llama_batch batch = llama_batch_init(512, 0, 1);
    std::string prompt = "Hello";

    std::vector<llama_token> tokens;
    tokens.push_back(llama_vocab_bos(test_ctx.vocab));

    std::vector<llama_token> prompt_tokens(32);
    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
                                  prompt_tokens.data(), prompt_tokens.size(),
                                  false, false);
    for (int i = 0; i < n_tokens; i++) {
        tokens.push_back(prompt_tokens[i]);
    }

    for (size_t i = 0; i < tokens.size(); i++) {
        // set all tokens as outputs to trigger the error
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }

    printf(">>> test_max_outputs expected error start:\n");
    const int ret = llama_decode(test_ctx.ctx.get(), batch);
    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
    printf("<<< test_max_outputs expected error end.\n");
    llama_batch_free(batch);

    printf("backend max outputs test PASSED\n");
}

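// Registry of all backend sampling tests; --test=<name> selects a single entry.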
struct backend_test_case {
    std::string name;
    void (*fn)(const test_params &);
    bool enabled_by_default;
};

static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",         test_backend_greedy_sampling,         true },
    { "logit_bias",     test_backend_logit_bias_sampling,     true },
    { "temp",           test_backend_temp_sampling,           true },
    { "temp_ext",       test_backend_temp_ext_sampling,       true },
    { "top_k",          test_backend_top_k_sampling,          true },
    { "multi_sequence", test_backend_multi_sequence_sampling, true },
    { "dist",           test_backend_dist_sampling,           true },
    { "dist_and_cpu",   test_backend_dist_sampling_and_cpu,   true },
    { "set_sampler",    test_backend_set_sampler,             true },
    { "max_outputs",    test_backend_max_outputs,             true },
    { "mixed",          test_backend_mixed_sampling,          true },
    { "min_p",          test_backend_min_p_sampling,          true },
    { "cpu_mixed",      test_backend_cpu_mixed_batch,         true },
    { "top_p",          test_backend_top_p_sampling,          true },
};

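// Parse --model, --test and --device (plus a positional model path) into test_args.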
static test_args parse_cli(int argc, char ** argv) {
    test_args out;

    for (int i = 1; i < argc; ++i) {
        const char * arg = argv[i];

        if (std::strcmp(arg, "--test") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--test expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.test = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--test=", 7) == 0) {
            out.test = arg + 7;
            continue;
        }
        if (std::strcmp(arg, "--model") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--model expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.model = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--model=", 8) == 0) {
            out.model = arg + 8;
            continue;
        }
        if (std::strcmp(arg, "--device") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
                exit(EXIT_FAILURE);
            }
            out.device = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--device=", 9) == 0) {
            out.device = arg + 9;
            continue;
        }
        if (out.model.empty()) {
            out.model = arg;
            continue;
        }

        fprintf(stderr, "Unexpected argument: %s\n", arg);
        exit(EXIT_FAILURE);
    }

    if (out.device != "cpu" && out.device != "gpu" && out.device != "auto") {
        fprintf(stderr, "Invalid device '%s'. Must be 'cpu', 'gpu' or 'auto'\n", out.device.c_str());
        exit(EXIT_FAILURE);
    }

    return out;
}

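// Return the requested test, or all tests that are enabled by default when no
// --test argument was given.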
static std::vector<const backend_test_case *> collect_tests_to_run(const std::string & requested) {
    std::vector<const backend_test_case *> selected;

    if (!requested.empty()) {
        for (const auto & test : BACKEND_TESTS) {
            if (test.name == requested) {
                selected.push_back(&test);
                break;
            }
        }
        if (selected.empty()) {
            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested.c_str());
            for (const auto & test : BACKEND_TESTS) {
                fprintf(stderr, " %s\n", test.name.c_str());
            }
            exit(EXIT_FAILURE);
        }
    } else {
        for (const auto & test : BACKEND_TESTS) {
            if (test.enabled_by_default) {
                selected.push_back(&test);
            }
        }
    }

    if (selected.empty()) {
        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
    }

    return selected;
}

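// Run the selected tests, aborting on the first exception.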
static void run_tests(const std::vector<const backend_test_case *> & tests, const test_params & params) {
    for (const auto & test : tests) {
        fprintf(stderr, "\n=== %s ===\n", test->name.c_str());
        try {
            test->fn(params);
        } catch (const std::exception & e) {
            fprintf(stderr, "Error running test '%s': %s\n", test->name.c_str(), e.what());
            exit(EXIT_FAILURE);
        }
    }
}

int main(int argc, char ** argv) {
    test_args args = parse_cli(argc, argv);

    if (args.model.empty()) {
        args.model = get_model_or_exit(1, argv);
    }

    {
        std::ifstream file(args.model);
        if (!file.is_open()) {
            fprintf(stderr, "no model '%s' found\n", args.model.c_str());
            return EXIT_FAILURE;
        }
    }

    fprintf(stderr, "using '%s'\n", args.model.c_str());

    llama_backend_init();

    test_params params = {
        /*.model =*/ load_model(args),
    };

    if (!params.model) {
        // load_model already printed a warning - treat an unloadable model as a skip
        return 0;
    }

    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
    if (!tests.empty()) {
        run_tests(tests, params);
    }

    return 0;
}