1#include "ggml.h"
   2#include "llama.h"
   3#include "llama-cpp.h"
   4#include "get-model.h"
   5#include "common.h"
   6
   7#ifdef NDEBUG
   8#undef NDEBUG
   9#endif
  10
  11#include <algorithm>
  12#include <cstdlib>
  13#include <cstring>
  14#include <fstream>
  15#include <map>
  16#include <string>
  17#include <unordered_map>
  18#include <vector>
  19
  20struct test_args {
  21    std::string model;
  22    std::string test;
  23    std::string device = "auto";
  24};
  25
  26struct test_params {
  27    llama_model_ptr model;
  28};
  29
  30static llama_model_ptr load_model(const test_args & args) {
  31    auto mparams = llama_model_default_params();
  32
  33    ggml_backend_dev_t devs[2] = { nullptr, nullptr };
  34
  35    if (args.device != "auto") {
  36        if (args.device == "gpu") {
  37            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
  38
  39            if (devs[0] == nullptr) {
  40                fprintf(stderr, "Error: GPU requested but not available\n");
  41                return nullptr;
  42            }
  43
  44            mparams.n_gpu_layers = 999;
  45        } else if (args.device == "cpu") {
  46            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
  47
  48            mparams.n_gpu_layers = 0;
  49        } else {
  50            fprintf(stderr, "Error: invalid device '%s'\n", args.device.c_str());
  51            return nullptr;
  52        }
  53
  54        mparams.devices = devs;
  55
  56        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
  57    }
  58
  59    llama_model_ptr res;
  60
  61    res.reset(llama_model_load_from_file(args.model.c_str(), mparams));
  62
  63    if (!res) {
  64        fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model.c_str());
  65        return nullptr;
  66    }
  67
  68    return res;
  69}
  70
  71struct test_context {
  72    llama_context_ptr ctx;
  73
  74    int n_vocab = 0;
  75
  76    const llama_vocab * vocab = nullptr;
  77
  78    std::unordered_map<llama_seq_id, int32_t> seq_positions;
  79    std::unordered_map<llama_seq_id, int32_t> last_batch_info;
  80
  81    test_context(const test_params & params, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
  82        auto * model = params.model.get();
  83
  84        GGML_ASSERT(model);
  85        GGML_ASSERT(!ctx);
  86
  87        llama_context_params cparams = llama_context_default_params();
  88        cparams.n_ctx = 512;
  89        cparams.n_batch = 512;
  90        cparams.samplers = configs.data();
  91        cparams.n_samplers = configs.size();
  92
  93        // If n_seq_max is not specified, calculate it from configs
  94        if (n_seq_max < 0) {
  95            int32_t max_seq_id = 0;
  96            for (const auto & config : configs) {
  97                max_seq_id = std::max(config.seq_id, max_seq_id);
  98            }
  99            cparams.n_seq_max = max_seq_id + 1;
 100        } else {
 101            cparams.n_seq_max = n_seq_max;
 102        }
 103
 104        ctx.reset(llama_init_from_model(model, cparams));
 105        if (!ctx) {
 106            throw std::runtime_error("failed to create context");
 107        }
 108
 109        llama_set_warmup(ctx.get(), false);
 110
 111        vocab = llama_model_get_vocab(model);
 112        n_vocab = llama_vocab_n_tokens(vocab);
 113    }
 114
 115    bool decode(const std::map<llama_seq_id, std::string> & prompts) {
 116        GGML_ASSERT(ctx);
 117
 118        last_batch_info.clear();
 119        llama_batch batch = llama_batch_init(512, 0, prompts.size());
 120
 121        for (const auto & [seq_id, prompt] : prompts) {
 122            std::vector<llama_token> tokens;
 123            tokens.push_back(llama_vocab_bos(vocab));
 124
 125            std::vector<llama_token> prompt_tokens(32);
 126            int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
 127                                           prompt_tokens.data(), prompt_tokens.size(),
 128                                           false, false);
 129            if (n_tokens < 0) {
 130                fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
 131                llama_batch_free(batch);
 132                return false;
 133            }
 134
 135            for (int i = 0; i < n_tokens; i++) {
 136                tokens.push_back(prompt_tokens[i]);
 137            }
 138
 139            if (seq_positions.find(seq_id) == seq_positions.end()) {
 140                seq_positions[seq_id] = 0;
 141            }
 142
 143            int32_t start_pos = seq_positions[seq_id];
 144            for (size_t i = 0; i < tokens.size(); i++) {
 145                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
 146            }
 147
 148            seq_positions[seq_id] = start_pos + tokens.size();
 149        }
 150
 151
 152        printf("Batch contents:\n");
 153        printf("n_tokens: %d\n", batch.n_tokens);
 154        for (int i = 0; i < batch.n_tokens; i++) {
 155            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
 156
 157            for (int j = 0; j < batch.n_seq_id[i]; j++) {
 158                printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
 159            }
 160            printf("], logits=%d\n", batch.logits[i]);
 161        }
 162
 163        if (llama_decode(ctx.get(), batch) != 0) {
 164            fprintf(stderr, "Warning: llama_decode failed\n");
 165            llama_batch_free(batch);
 166            return false;
 167        }
 168
 169        // Build mapping from seq id to batch token idx
 170        for (int i = 0; i < batch.n_tokens; i++) {
 171            if (batch.logits[i]) {
 172                llama_seq_id seq_id = batch.seq_id[i][0];
 173                last_batch_info[seq_id] = i;
 174            }
 175        }
 176
 177        llama_batch_free(batch);
 178        return true;
 179    }
 180
 181    int32_t idx_for_seq(llama_seq_id seq_id) {
 182        auto it = last_batch_info.find(seq_id);
 183        if (it == last_batch_info.end()) {
 184            fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
 185            return -1;
 186        }
 187        return it->second;
 188    }
 189
 190    void update_batch_info(const llama_batch & batch) {
 191        last_batch_info.clear();
 192        for (int i = 0; i < batch.n_tokens; i++) {
 193            if (batch.logits[i]) {
 194                llama_seq_id cur_seq = batch.seq_id[i][0];
 195                last_batch_info[cur_seq] = i;
 196            }
 197        }
 198    }
 199
 200    bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
 201        GGML_ASSERT(ctx);
 202
 203        llama_batch batch = llama_batch_init(1, 0, 1);
 204        int32_t pos = seq_positions[seq_id];
 205        common_batch_add(batch, token, pos, { seq_id }, true);
 206
 207        if (llama_decode(ctx.get(), batch) != 0) {
 208            fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
 209            llama_batch_free(batch);
 210            return false;
 211        }
 212
 213        update_batch_info(batch);
 214
 215        seq_positions[seq_id]++;
 216        llama_batch_free(batch);
 217
 218        return true;
 219    }
 220
 221    bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
 222        GGML_ASSERT(ctx);
 223
 224        llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
 225
 226        for (const auto & [seq_id, token] : seq_tokens) {
 227            int32_t pos = seq_positions[seq_id];
 228            common_batch_add(batch, token, pos, { seq_id }, true);
 229        }
 230
 231        if (llama_decode(ctx.get(), batch) != 0) {
 232            fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
 233            llama_batch_free(batch);
 234            return false;
 235        }
 236
 237        for (const auto & [seq_id, _] : seq_tokens) {
 238            seq_positions[seq_id]++;
 239        }
 240
 241        update_batch_info(batch);
 242
 243        llama_batch_free(batch);
 244
 245        return true;
 246    }
 247
 248    std::string token_to_piece(llama_token token, bool special) const {
 249        std::string piece;
 250        piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
 251        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
 252        if (n_chars < 0) {
 253            piece.resize(-n_chars);
 254            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
 255            GGML_ASSERT(check == -n_chars);
 256        } else {
 257            piece.resize(n_chars);
 258        }
 259
 260        return piece;
 261    }
 262};
 263
 264static void test_backend_greedy_sampling(const test_params & params) {
 265    const int seq_id = 0;
 266
 267    struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
 268    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));
 269
 270    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
 271    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 272
 273    test_context test_ctx(params, backend_sampler_configs);
 274
 275    if (!test_ctx.decode({{seq_id, "Some"}})) {
 276        GGML_ASSERT(false && "Failed to decode token");
 277    }
 278
 279    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 280
 281    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 282    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
 283    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 284
 285    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
 286    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
 287    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 288
 289    for (int i = 0; i < 10; i++) {
 290        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
 291        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
 292        printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
 293        if (!test_ctx.decode_token(token, 0)) {
 294            GGML_ASSERT(false && "Failed to decode token");
 295        }
 296    }
 297}
 298
 299static void test_backend_top_k_sampling(const test_params & params) {
 300    const int seq_id = 0;
 301    const int32_t k = 8;
 302    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 303    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 304    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
 305    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 306
 307    test_context test_ctx(params, backend_sampler_configs);
 308
 309    if (!test_ctx.decode({{seq_id, "Hello"}})) {
 310        GGML_ASSERT(false && "Failed to decode token");
 311    }
 312
 313    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 314
 315    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
 316    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 317    for (size_t i = 0; i < n_logits; ++i) {
 318        printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
 319    }
 320
 321    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
 322    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
 323    for (size_t i = 0; i < n_candidates; ++i) {
 324        printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
 325               test_ctx.token_to_piece(candidates[i], false).c_str());
 326    }
 327
 328    // Sample using CPU sampler for verification that it is possible to do hybrid
 329    // sampling, first top_k on the backend and then dist on the CPU.
 330    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 331    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 332    GGML_ASSERT(chain->iface->backend_apply != nullptr);
 333
 334    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 335    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 336    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 337
 338    printf("backend top-k hybrid sampling test PASSED\n");
 339}
 340
 341static void test_backend_temp_sampling(const test_params & params) {
 342    {
 343        const float temp_0 = 0.8f;
 344        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
 345        llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
 346        llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));
 347
 348        const float temp_1 = 0.1f;
 349        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
 350        llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
 351        llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));
 352
 353        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 354            { 0, backend_sampler_chain_0.get() },
 355            { 1, backend_sampler_chain_1.get() }
 356        };
 357
 358        test_context test_ctx(params, backend_sampler_configs);
 359
 360        if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
 361            GGML_ASSERT(false && "Failed to decode token");
 362        }
 363
 364        // Verfify sequence 0
 365        {
 366            int32_t batch_idx = test_ctx.idx_for_seq(0);
 367            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 368            GGML_ASSERT(n_logits == test_ctx.n_vocab);
 369
 370            // Sample from sequence 0 using CPU sampler
 371            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 372            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 373            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 374
 375            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 376            const std::string token_str = test_ctx.token_to_piece(token, false);
 377            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 378            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 379        }
 380
 381
 382        // Verfify sequence 1
 383        {
 384            int32_t batch_idx = test_ctx.idx_for_seq(1);
 385
 386            // Sample from sequence 1 using CPU sampler
 387            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 388            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 389            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 390
 391            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 392            const std::string token_str = test_ctx.token_to_piece(token, false);
 393            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 394            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 395        }
 396    }
 397
 398    // lambda to testing non-positive temperature values.
 399    auto test_argmax_temp = [&](float temp) {
 400        printf("\nTesting temperature = %.1f\n", temp);
 401
 402        int seq_id = 0;
 403        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 404        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 405        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));
 406
 407        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 408            { seq_id, backend_sampler_chain.get() },
 409        };
 410
 411        test_context test_ctx(params, backend_sampler_configs);
 412
 413        if (!test_ctx.decode({{seq_id, "Once"}})) {
 414            GGML_ASSERT(false && "Failed to decode token");
 415        }
 416
 417        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 418
 419        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 420        GGML_ASSERT(n_logits == 1);
 421    };
 422
 423    test_argmax_temp(0.0f);
 424    test_argmax_temp(-1.0f);
 425
 426    printf("backend temp sampling test PASSED\n");
 427}
 428
 429static void test_backend_temp_ext_sampling(const test_params & params) {
 430    {
 431        int seq_id = 0;
 432        const float temp = 0.8f;
 433        const float delta = 0.5f;
 434        const float exponent = 1.5f;
 435        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 436        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 437        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 438
 439        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 440            { seq_id, backend_sampler_chain.get() },
 441        };
 442
 443        test_context test_ctx(params, backend_sampler_configs);
 444
 445        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
 446            GGML_ASSERT(false && "Failed to decode token");
 447        }
 448
 449        // Verify sequence 0
 450        {
 451            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 452            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 453            GGML_ASSERT(n_logits == test_ctx.n_vocab);
 454        }
 455    }
 456
 457    // lambda to testing non-positive temp/delta/exponent values.
 458    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
 459        printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
 460
 461        int seq_id = 0;
 462        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 463        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 464        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));
 465
 466        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 467            { seq_id, backend_sampler_chain.get() },
 468        };
 469
 470        test_context test_ctx(params, backend_sampler_configs);
 471
 472        if (!test_ctx.decode({{seq_id, "Once"}})) {
 473            GGML_ASSERT(false && "Failed to decode token");
 474        }
 475
 476        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 477
 478        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 479
 480        if (temp <= 0.0f && delta >= 0.0f) {
 481            GGML_ASSERT(n_logits == 1);
 482        } else {
 483            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
 484        }
 485    };
 486
 487    test_argmax_temp(0.0f,  0.3f, 1.0f); // Greedy (temp=0)
 488    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
 489    test_argmax_temp(0.8f,  0.0f, 2.0f); // Temperature scaling
 490
 491    printf("backend temp_ext sampling test PASSED\n");
 492}
 493
 494static void test_backend_min_p_sampling(const test_params & params) {
 495    const int seq_id = 0;
 496    const float p = 0.1;
 497    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 498    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 499    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
 500    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 501
 502    test_context test_ctx(params, backend_sampler_configs);
 503
 504    if (!test_ctx.decode({{seq_id, "Hello"}})) {
 505        GGML_ASSERT(false && "Failed to decode token");
 506    }
 507
 508    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 509
 510    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
 511    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 512
 513    // Print the logits that are above the min-p threshold
 514    std::vector<float> filtered_logits;
 515    for (size_t i = 0; i < n_logits; ++i) {
 516        if (logits[i] > -1e9f) {
 517            filtered_logits.push_back(logits[i]);
 518            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
 519        }
 520    }
 521    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
 522
 523    // Sample using CPU sampler for verification to inspect they are reasonable
 524    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 525    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 526    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 527
 528    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 529    const std::string token_str = test_ctx.token_to_piece(token, false);
 530    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 531    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 532
 533    // Decode and sampler 10 more tokens
 534    for (int i = 0; i < 10; i++) {
 535        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
 536        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
 537        printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
 538        if (!test_ctx.decode_token(token, 0)) {
 539            GGML_ASSERT(false && "Failed to decode token");
 540        }
 541    }
 542
 543    printf("min-p sampling test PASSED\n");
 544}
 545
 546static void test_backend_top_p_sampling(const test_params & params) {
 547    const int seq_id = 0;
 548    const float p = 0.9;
 549    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 550    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 551    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
 552    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 553
 554    test_context test_ctx(params, backend_sampler_configs);
 555
 556    if (!test_ctx.decode({{seq_id, "Hello"}})) {
 557        return;
 558    }
 559
 560    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 561
 562    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
 563    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 564
 565    // Print the logits that are above the min-p threshold
 566    std::vector<float> filtered_logits;
 567    for (size_t i = 0; i < n_logits; ++i) {
 568        if (logits[i] > -1e9f) {
 569            filtered_logits.push_back(logits[i]);
 570        }
 571    }
 572    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
 573    GGML_ASSERT(filtered_logits.size() > 0);
 574
 575    // Sample using CPU sampler for verification to inspect they are reasonable
 576    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 577    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 578    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));
 579
 580    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 581    const std::string token_str = test_ctx.token_to_piece(token, false);
 582    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
 583    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 584
 585    // Decode and sampler 10 more tokens
 586    for (int i = 0; i < 10; i++) {
 587        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
 588        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
 589        printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
 590        test_ctx.decode_token(token, 0);
 591    }
 592
 593    printf("top-p sampling test PASSED\n");
 594}
 595
 596static void test_backend_multi_sequence_sampling(const test_params & params) {
 597    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
 598    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
 599    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());
 600
 601    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
 602    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
 603    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
 604    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());
 605
 606    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 607        { 0, sampler_chain_0.get() },
 608        { 1, sampler_chain_1.get() }
 609    };
 610
 611    test_context test_ctx(params, backend_sampler_configs);
 612
 613    std::map<llama_seq_id, std::string> prompts = {
 614        {0, "Hello"},
 615        {1, "Some"}
 616    };
 617
 618    if (!test_ctx.decode(prompts)) {
 619        GGML_ASSERT(false && "Failed to decode token");
 620    }
 621
 622    // Verfiy sequence 0
 623    {
 624        int32_t batch_idx = test_ctx.idx_for_seq(0);
 625        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 626        const std::string token_str = test_ctx.token_to_piece(token, false);
 627        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
 628        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 629    }
 630
 631    // Verify sequence 1
 632    {
 633        int32_t batch_idx= test_ctx.idx_for_seq(1);
 634        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 635        const std::string token_str = test_ctx.token_to_piece(token, false);
 636        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
 637        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 638    }
 639
 640    // Generate tokens for each sequence
 641    printf("\nMulti-sequence generation:\n");
 642    for (int step = 0; step < 4; step++) {
 643        std::map<llama_seq_id, llama_token> tokens;
 644
 645        for (llama_seq_id seq_id : {0, 1}) {
 646            int32_t idx = test_ctx.idx_for_seq(seq_id);
 647            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
 648            const std::string token_str = test_ctx.token_to_piece(token, false);
 649            printf("  Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
 650            tokens[seq_id] = token;
 651        }
 652
 653        // Decode all tokens in a single batch
 654        if (!test_ctx.decode_tokens(tokens)) {
 655            GGML_ASSERT(false && "Failed to decode token");
 656        }
 657    }
 658
 659    printf("backend multi-sequence sampling test PASSED\n");
 660}
 661
 662static void test_backend_dist_sampling(const test_params & params) {
 663    const int seq_id = 189;
 664    const int32_t seed = 88;
 665
 666    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 667    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 668    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
 669    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 670
 671    test_context test_ctx(params, backend_sampler_configs);
 672
 673    if (!test_ctx.decode({{seq_id, "Some"}})) {
 674        GGML_ASSERT(false && "Failed to decode token");
 675    }
 676
 677    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 678    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 679    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
 680    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 681    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
 682
 683    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
 684    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
 685    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 686
 687    printf("backend dist sampling test PASSED\n");
 688}
 689
 690static void test_backend_dist_sampling_and_cpu(const test_params & params) {
 691    const int seq_id = 0;
 692    const int32_t seed = 88;
 693
 694    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 695    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 696    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
 697    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 698
 699    test_context test_ctx(params, backend_sampler_configs);
 700
 701    if (!test_ctx.decode({{seq_id, "Some"}})) {
 702        GGML_ASSERT(false && "Failed to decode token");
 703    }
 704
 705    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 706
 707    // Sample using CPU sampler
 708    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 709    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 710    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 711
 712    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 713    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 714    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
 715    GGML_ASSERT(backend_token == cpu_token);
 716
 717    printf("backend dist & cpu sampling test PASSED\n");
 718}
 719
 720static void test_backend_logit_bias_sampling(const test_params & params) {
 721    const auto * model = params.model.get();
 722    const auto * vocab = llama_model_get_vocab(model);
 723
 724    const int seq_id = 0;
 725
 726    std::vector<llama_logit_bias> logit_bias;
 727
 728    // Get the token for the piece "World".
 729    const std::string piece = "World";
 730    std::vector<llama_token> tokens(16);
 731    llama_tokenize(vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
 732
 733    llama_token bias_token = tokens[0];
 734    // TODO: biasing too much here makes the Vulkan sampling fail - should be investigated further
 735    //       https://github.com/ggml-org/llama.cpp/actions/runs/20894267644/job/60030252675?pr=18753#step:3:23350
 736    //logit_bias.push_back({ bias_token, +100.0f });
 737    logit_bias.push_back({ bias_token, +10.0f });
 738
 739    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
 740
 741    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 742    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 743    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
 744                llama_vocab_n_tokens(vocab),
 745                logit_bias.size(),
 746                logit_bias.data()));
 747    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));
 748
 749    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 750        { seq_id, backend_sampler_chain.get() },
 751    };
 752
 753    test_context test_ctx(params, backend_sampler_configs);
 754
 755    if (!test_ctx.decode({{seq_id, "Hello"}})) {
 756        GGML_ASSERT(false && "Failed to decode token");
 757    }
 758
 759    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
 760    printf("sampled token = %d, expected = %d\n", backend_token, bias_token);
 761    GGML_ASSERT(backend_token == bias_token);
 762
 763    printf("backend logit bias sampling test PASSED\n");
 764}
 765
 766// This test verifies that it is possible to have two different backend sampler,
 767// one that uses the backend dist sampler, and another that uses CPU dist sampler.
 768static void test_backend_mixed_sampling(const test_params & params) {
 769    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
 770    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
 771    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 772
 773    int k = 40;
 774    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
 775    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
 776    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));
 777
 778    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 779        { 0, sampler_chain_0.get() },
 780        { 1, sampler_chain_1.get() }
 781    };
 782
 783    test_context test_ctx(params, backend_sampler_configs);
 784
 785    std::map<llama_seq_id, std::string> prompts = {
 786        {0, "Hello"},
 787        {1, "Some"}
 788    };
 789
 790    if (!test_ctx.decode(prompts)) {
 791        GGML_ASSERT(false && "Failed to decode token");
 792    }
 793
 794    // Verfiy sequence 0 that used the dist backend sampler.
 795    {
 796        int32_t batch_idx = test_ctx.idx_for_seq(0);
 797        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 798        const std::string token_str = test_ctx.token_to_piece(token, false);
 799        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
 800        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 801        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
 802        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
 803    }
 804
 805    // Verfiy sequence 1 that used the top-k backend sampler.
 806    {
 807        int32_t batch_idx = test_ctx.idx_for_seq(1);
 808        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
 809        GGML_ASSERT(logits != nullptr);
 810        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
 811        GGML_ASSERT(n_logits == (size_t) k);
 812        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
 813    }
 814
 815    printf("backend mixed sampling test PASSED\n");
 816}
 817
 818static void test_backend_set_sampler(const test_params & params) {
 819    const int seq_id = 0;
 820    const int32_t seed = 88;
 821
 822    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 823    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 824    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
 825    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 826
 827    test_context test_ctx(params, backend_sampler_configs);
 828
 829    if (!test_ctx.decode({{seq_id, "Hello"}})) {
 830        GGML_ASSERT(false && "Failed to decode token");
 831    }
 832
 833    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 834
 835    // Sample using backend sampler configured above
 836    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 837    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
 838    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
 839
 840    // Now clear the backend sampler for this sequence.
 841    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
 842    printf("Cleared backend sampler for seq_id %d\n", seq_id);
 843
 844    // Sample using CPU sampler
 845    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 846    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 847    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));
 848
 849    std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
 850    if (!test_ctx.decode_tokens(tokens)) {
 851        GGML_ASSERT(false && "Failed to decode token");
 852    }
 853
 854    // Should not have any sampled token or probs after clearing the backend sampler.
 855    const int32_t idx = test_ctx.idx_for_seq(seq_id);
 856    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
 857    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);
 858
 859    // Sample the token using the CPU sampler chain.
 860    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
 861    const std::string token2_str = test_ctx.token_to_piece(token2, false);
 862    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
 863    std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };
 864
 865    // Set a new backend sampler for the sequence.
 866    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
 867    llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
 868    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
 869    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
 870    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());
 871
 872    if (!test_ctx.decode_tokens(tokens2)) {
 873        GGML_ASSERT(false && "Failed to decode token");
 874    }
 875
 876    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
 877    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
 878    printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 879
 880    printf("backend set sampler test PASSED\n");
 881}
 882
 883static void test_backend_cpu_mixed_batch(const test_params & params) {
 884    // Sequence 0 uses backend sampling
 885    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
 886    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
 887    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));
 888
 889    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
 890        { 0, sampler_chain_0.get() },
 891    };
 892
 893    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
 894    test_context test_ctx(params, backend_sampler_configs, 2);
 895
 896    std::map<llama_seq_id, std::string> prompts = {
 897        {0, "Hello"}, // Will use backend sampling
 898        {1, "Some"}   // Will use CPU sampling
 899    };
 900
 901    if (!test_ctx.decode(prompts)) {
 902        GGML_ASSERT(false && "Failed to decode token");
 903    }
 904
 905    // Verify sequence 0 (backend sampled)
 906    {
 907        int32_t batch_idx = test_ctx.idx_for_seq(0);
 908        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 909        const std::string token_str = test_ctx.token_to_piece(token, false);
 910        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
 911        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 912    }
 913
 914    // Verify sequence 1 (CPU sampled)
 915    {
 916        int32_t batch_idx = test_ctx.idx_for_seq(1);
 917
 918        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 919        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
 920
 921        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 922        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 923        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 924
 925        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 926        const std::string token_str = test_ctx.token_to_piece(token, false);
 927        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
 928        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 929    }
 930
 931    // Clear/remove the backend sampler, and sample again
 932    {
 933        // clear the backend sampler for seq 0 so that there are no backend
 934        // samplers.
 935        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
 936
 937        // Create a CPU sampler and verify we can sampler from it.
 938        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 939        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
 940        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
 941
 942        int32_t batch_idx = test_ctx.idx_for_seq(1);
 943        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
 944        if (!test_ctx.decode_token(token, 1)) {
 945            GGML_ASSERT(false && "Failed to decode token");
 946        }
 947    }
 948
 949    // Set a backend sampler so that we can verify that it can be reset
 950    {
 951        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
 952        llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
 953        llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
 954
 955        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());
 956
 957        if (!test_ctx.decode_token(3834, 0)) {
 958            GGML_ASSERT(false && "Failed to decode token");
 959        }
 960
 961        int32_t batch_idx = test_ctx.idx_for_seq(0);
 962        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
 963        const std::string token_str = test_ctx.token_to_piece(token, false);
 964        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
 965        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 966    }
 967
 968    printf("backend-cpu mixed batch test PASSED\n");
 969}
 970
 971static void test_backend_max_outputs(const test_params & params) {
 972    const int seq_id = 0;
 973    const int32_t seed = 88;
 974
 975    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
 976    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
 977    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
 978    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
 979
 980    test_context test_ctx(params, backend_sampler_configs);
 981
 982    llama_batch batch = llama_batch_init(512, 0, 1);
 983    std::string prompt = "Hello";
 984
 985    std::vector<llama_token> tokens;
 986    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
 987
 988    std::vector<llama_token> prompt_tokens(32);
 989    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
 990                                   prompt_tokens.data(), prompt_tokens.size(),
 991                                   false, false);
 992    for (int i = 0; i < n_tokens; i++) {
 993        tokens.push_back(prompt_tokens[i]);
 994    }
 995
 996    for (size_t i = 0; i < tokens.size(); i++) {
 997        // set all tokens as output to trigger error
 998        common_batch_add(batch, tokens[i], i, { seq_id }, true);
 999    }
1000
1001    printf(">>> test_max_outputs expected error start:\n");
1002    const int ret = llama_decode(test_ctx.ctx.get(), batch);
1003    GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
1004    printf("<<< test_max_outputs expected error end.\n");
1005    llama_batch_free(batch);
1006
1007    printf("backend max outputs test PASSED\n");
1008}
1009
1010struct backend_test_case {
1011    std::string name;
1012    void (*fn)(const test_params &);
1013    bool enabled_by_default;
1014};
1015
// Registry of all backend sampling tests. When no --test is given, the enabled entries run
// in the order listed here. `name` is the value matched by --test=<name>. Entries with
// enabled_by_default == false would only run when requested explicitly; currently every
// entry is enabled by default.
static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",          test_backend_greedy_sampling,         true  },
    { "logit_bias",      test_backend_logit_bias_sampling,     true  },
    { "temp",            test_backend_temp_sampling,           true  },
    { "temp_ext",        test_backend_temp_ext_sampling,       true  },
    { "top_k",           test_backend_top_k_sampling,          true  },
    { "multi_sequence",  test_backend_multi_sequence_sampling, true  },
    { "dist",            test_backend_dist_sampling,           true  },
    { "dist_and_cpu",    test_backend_dist_sampling_and_cpu,   true  },
    { "set_sampler",     test_backend_set_sampler,             true  },
    { "max_outputs",     test_backend_max_outputs,             true  },
    { "mixed",           test_backend_mixed_sampling,          true  },
    { "min_p",           test_backend_min_p_sampling,          true  },
    { "cpu_mixed",       test_backend_cpu_mixed_batch,         true  },
    { "top_p",           test_backend_top_p_sampling,          true  },
};
1032
1033static test_args parse_cli(int argc, char ** argv) {
1034    test_args out;
1035
1036    for (int i = 1; i < argc; ++i) {
1037        const char * arg = argv[i];
1038
1039        if (std::strcmp(arg, "--test") == 0) {
1040            if (i + 1 >= argc) {
1041                fprintf(stderr, "--test expects a value\n");
1042                exit(EXIT_FAILURE);
1043            }
1044            out.test = argv[++i];
1045            continue;
1046        }
1047        if (std::strncmp(arg, "--test=", 7) == 0) {
1048            out.test = arg + 7;
1049            continue;
1050        }
1051        if (std::strcmp(arg, "--model") == 0) {
1052            if (i + 1 >= argc) {
1053                fprintf(stderr, "--model expects a value\n");
1054                exit(EXIT_FAILURE);
1055            }
1056            out.model = argv[++i];
1057            continue;
1058        }
1059        if (std::strncmp(arg, "--model=", 8) == 0) {
1060            out.model = arg + 8;
1061            continue;
1062        }
1063        if (std::strcmp(arg, "--device") == 0) {
1064            if (i + 1 >= argc) {
1065                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
1066                exit(EXIT_FAILURE);
1067            }
1068            out.device = argv[++i];
1069            continue;
1070        }
1071        if (std::strncmp(arg, "--device=", 9) == 0) {
1072            out.device = arg + 9;
1073            continue;
1074        }
1075        if (out.model.empty()) {
1076            out.model = arg;
1077            continue;
1078        }
1079
1080        fprintf(stderr, "Unexpected argument: %s\n", arg);
1081        exit(EXIT_FAILURE);
1082    }
1083
1084    if (out.device != "cpu" && out.device != "gpu" && out.device != "auto") {
1085        fprintf(stderr, "Invalid device '%s'. Must be 'cpu', 'gpu' or 'auto'\n", out.device.c_str());
1086        exit(EXIT_FAILURE);
1087    }
1088
1089    return out;
1090}
1091
1092static std::vector<const backend_test_case *> collect_tests_to_run(const std::string & requested) {
1093    std::vector<const backend_test_case *> selected;
1094
1095    if (!requested.empty()) {
1096        for (const auto & test : BACKEND_TESTS) {
1097            if (test.name == requested) {
1098                selected.push_back(&test);
1099                break;
1100            }
1101        }
1102        if (selected.empty()) {
1103            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested.c_str());
1104            for (const auto & test : BACKEND_TESTS) {
1105                fprintf(stderr, "  %s\n", test.name.c_str());
1106            }
1107            exit(EXIT_FAILURE);
1108        }
1109    } else {
1110        for (const auto & test : BACKEND_TESTS) {
1111            if (test.enabled_by_default) {
1112                selected.push_back(&test);
1113            }
1114        }
1115    }
1116
1117    if (selected.empty()) {
1118        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
1119    }
1120
1121    return selected;
1122}
1123
1124static void run_tests(const std::vector<const backend_test_case *> & tests, const test_params & args) {
1125    for (const auto & test : tests) {
1126        fprintf(stderr, "\n=== %s ===\n", test->name.c_str());
1127        try {
1128            test->fn(args);
1129        } catch (const std::exception & e) {
1130            fprintf(stderr, "Error running test '%s': %s\n", test->name.c_str(), e.what());
1131            exit(EXIT_FAILURE);
1132        }
1133    }
1134}
1135
1136int main(int argc, char ** argv) {
1137    test_args args = parse_cli(argc, argv);
1138
1139    if (args.model.empty()) {
1140        args.model = get_model_or_exit(1, argv);
1141    }
1142
1143    {
1144        std::ifstream file(args.model);
1145        if (!file.is_open()) {
1146            fprintf(stderr, "no model '%s' found\n", args.model.c_str());
1147            return EXIT_FAILURE;
1148        }
1149    }
1150
1151    fprintf(stderr, "using '%s'\n", args.model.c_str());
1152
1153    llama_backend_init();
1154
1155    test_params params = {
1156        /*.model =*/ load_model(args),
1157    };
1158
1159    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
1160    if (!tests.empty()) {
1161        run_tests(tests, params);
1162    }
1163
1164    return 0;
1165}