Diffstat (limited to 'llama.cpp/tests/test-backend-sampler.cpp')
 -rw-r--r--   llama.cpp/tests/test-backend-sampler.cpp   1165
1 file changed, 1165 insertions, 0 deletions
diff --git a/llama.cpp/tests/test-backend-sampler.cpp b/llama.cpp/tests/test-backend-sampler.cpp new file mode 100644 index 0000000..c10bde9 --- /dev/null +++ b/llama.cpp/tests/test-backend-sampler.cpp @@ -0,0 +1,1165 @@ +#include "ggml.h" +#include "llama.h" +#include "llama-cpp.h" +#include "get-model.h" +#include "common.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include <algorithm> +#include <cstdlib> +#include <cstring> +#include <fstream> +#include <map> +#include <string> +#include <unordered_map> +#include <vector> + +struct test_args { + std::string model; + std::string test; + std::string device = "auto"; +}; + +struct test_params { + llama_model_ptr model; +}; + +static llama_model_ptr load_model(const test_args & args) { + auto mparams = llama_model_default_params(); + + ggml_backend_dev_t devs[2] = { nullptr, nullptr }; + + if (args.device != "auto") { + if (args.device == "gpu") { + devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + + if (devs[0] == nullptr) { + fprintf(stderr, "Error: GPU requested but not available\n"); + return nullptr; + } + + mparams.n_gpu_layers = 999; + } else if (args.device == "cpu") { + devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + + mparams.n_gpu_layers = 0; + } else { + fprintf(stderr, "Error: invalid device '%s'\n", args.device.c_str()); + return nullptr; + } + + mparams.devices = devs; + + fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0])); + } + + llama_model_ptr res; + + res.reset(llama_model_load_from_file(args.model.c_str(), mparams)); + + if (!res) { + fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model.c_str()); + return nullptr; + } + + return res; +} + +struct test_context { + llama_context_ptr ctx; + + int n_vocab = 0; + + const llama_vocab * vocab = nullptr; + + std::unordered_map<llama_seq_id, int32_t> seq_positions; + std::unordered_map<llama_seq_id, int32_t> last_batch_info; + + test_context(const test_params & params, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) { + auto * model = params.model.get(); + + GGML_ASSERT(model); + GGML_ASSERT(!ctx); + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = 512; + cparams.n_batch = 512; + cparams.samplers = configs.data(); + cparams.n_samplers = configs.size(); + + // If n_seq_max is not specified, calculate it from configs + if (n_seq_max < 0) { + int32_t max_seq_id = 0; + for (const auto & config : configs) { + max_seq_id = std::max(config.seq_id, max_seq_id); + } + cparams.n_seq_max = max_seq_id + 1; + } else { + cparams.n_seq_max = n_seq_max; + } + + ctx.reset(llama_init_from_model(model, cparams)); + if (!ctx) { + throw std::runtime_error("failed to create context"); + } + + llama_set_warmup(ctx.get(), false); + + vocab = llama_model_get_vocab(model); + n_vocab = llama_vocab_n_tokens(vocab); + } + + bool decode(const std::map<llama_seq_id, std::string> & prompts) { + GGML_ASSERT(ctx); + + last_batch_info.clear(); + llama_batch batch = llama_batch_init(512, 0, prompts.size()); + + for (const auto & [seq_id, prompt] : prompts) { + std::vector<llama_token> tokens; + tokens.push_back(llama_vocab_bos(vocab)); + + std::vector<llama_token> prompt_tokens(32); + int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + if (n_tokens < 0) { + fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id); + llama_batch_free(batch); + return false; + } + + for (int 
i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + if (seq_positions.find(seq_id) == seq_positions.end()) { + seq_positions[seq_id] = 0; + } + + int32_t start_pos = seq_positions[seq_id]; + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1); + } + + seq_positions[seq_id] = start_pos + tokens.size(); + } + + + printf("Batch contents:\n"); + printf("n_tokens: %d\n", batch.n_tokens); + for (int i = 0; i < batch.n_tokens; i++) { + printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]); + + for (int j = 0; j < batch.n_seq_id[i]; j++) { + printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : ""); + } + printf("], logits=%d\n", batch.logits[i]); + } + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed\n"); + llama_batch_free(batch); + return false; + } + + // Build mapping from seq id to batch token idx + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id seq_id = batch.seq_id[i][0]; + last_batch_info[seq_id] = i; + } + } + + llama_batch_free(batch); + return true; + } + + int32_t idx_for_seq(llama_seq_id seq_id) { + auto it = last_batch_info.find(seq_id); + if (it == last_batch_info.end()) { + fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id); + return -1; + } + return it->second; + } + + void update_batch_info(const llama_batch & batch) { + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + } + + bool decode_token(llama_token token, llama_seq_id seq_id = 0) { + GGML_ASSERT(ctx); + + llama_batch batch = llama_batch_init(1, 0, 1); + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id); + llama_batch_free(batch); + return false; + } + + update_batch_info(batch); + + seq_positions[seq_id]++; + llama_batch_free(batch); + + return true; + } + + bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) { + GGML_ASSERT(ctx); + + llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size()); + + for (const auto & [seq_id, token] : seq_tokens) { + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + } + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for batch tokens\n"); + llama_batch_free(batch); + return false; + } + + for (const auto & [seq_id, _] : seq_tokens) { + seq_positions[seq_id]++; + } + + update_batch_info(batch); + + llama_batch_free(batch); + + return true; + } + + std::string token_to_piece(llama_token token, bool special) const { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } else { + piece.resize(n_chars); + } + + return piece; + } +}; + +static void test_backend_greedy_sampling(const test_params & params) { + const int seq_id = 0; + + struct llama_sampler_chain_params 
backend_sampler_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params)); + + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy()); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx); + printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } +} + +static void test_backend_top_k_sampling(const test_params & params) { + const int seq_id = 0; + const int32_t k = 8; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + for (size_t i = 0; i < n_logits; ++i) { + printf("top_k logit[%zu] = %.6f\n", i, logits[i]); + } + + llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx); + for (size_t i = 0; i < n_candidates; ++i) { + printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i], + test_ctx.token_to_piece(candidates[i], false).c_str()); + } + + // Sample using CPU sampler for verification that it is possible to do hybrid + // sampling, first top_k on the backend and then dist on the CPU. 
+ struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + GGML_ASSERT(chain->iface->backend_apply != nullptr); + + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend top-k hybrid sampling test PASSED\n"); +} + +static void test_backend_temp_sampling(const test_params & params) { + { + const float temp_0 = 0.8f; + struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0)); + llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0)); + + const float temp_1 = 0.1f; + struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1)); + llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { 0, backend_sampler_chain_0.get() }, + { 1, backend_sampler_chain_1.get() } + }; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfify sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + + // Sample from sequence 0 using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + + // Verfify sequence 1 + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + + // Sample from sequence 1 using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + } + + // lambda to testing non-positive temperature values. 
+ auto test_argmax_temp = [&](float temp) { + printf("\nTesting temperature = %.1f\n", temp); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == 1); + }; + + test_argmax_temp(0.0f); + test_argmax_temp(-1.0f); + + printf("backend temp sampling test PASSED\n"); +} + +static void test_backend_temp_ext_sampling(const test_params & params) { + { + int seq_id = 0; + const float temp = 0.8f; + const float delta = 0.5f; + const float exponent = 1.5f; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Once upon a"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + } + } + + // lambda to testing non-positive temp/delta/exponent values. 
+ auto test_argmax_temp = [&](float temp, float delta, float exponent) { + printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + if (temp <= 0.0f && delta >= 0.0f) { + GGML_ASSERT(n_logits == 1); + } else { + GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); + } + }; + + test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0) + test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0) + test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling + + printf("backend temp_ext sampling test PASSED\n"); +} + +static void test_backend_min_p_sampling(const test_params & params) { + const int seq_id = 0; + const float p = 0.1; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + // Print the logits that are above the min-p threshold + std::vector<float> filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + //printf("min_p logit[%zu] = %.6f\n", i, logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); + printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + 
} + } + + printf("min-p sampling test PASSED\n"); +} + +static void test_backend_top_p_sampling(const test_params & params) { + const int seq_id = 0; + const float p = 0.9; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + // Print the logits that are above the min-p threshold + std::vector<float> filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + GGML_ASSERT(filtered_logits.size() > 0); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); + printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + test_ctx.decode_token(token, 0); + } + + printf("top-p sampling test PASSED\n"); +} + +static void test_backend_multi_sequence_sampling(const test_params & params) { + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy()); + + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy()); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + { 1, sampler_chain_1.get() } + }; + + test_context test_ctx(params, backend_sampler_configs); + + std::map<llama_seq_id, std::string> prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfiy sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 sampled token 
id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 + { + int32_t batch_idx= test_ctx.idx_for_seq(1); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Generate tokens for each sequence + printf("\nMulti-sequence generation:\n"); + for (int step = 0; step < 4; step++) { + std::map<llama_seq_id, llama_token> tokens; + + for (llama_seq_id seq_id : {0, 1}) { + int32_t idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str()); + tokens[seq_id] = token; + } + + // Decode all tokens in a single batch + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + printf("backend multi-sequence sampling test PASSED\n"); +} + +static void test_backend_dist_sampling(const test_params & params) { + const int seq_id = 189; + const int32_t seed = 88; + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr); + + token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend dist sampling test PASSED\n"); +} + +static void test_backend_dist_sampling_and_cpu(const test_params & params) { + const int seq_id = 0; + const int32_t seed = 88; + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token 
backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str()); + GGML_ASSERT(backend_token == cpu_token); + + printf("backend dist & cpu sampling test PASSED\n"); +} + +static void test_backend_logit_bias_sampling(const test_params & params) { + const auto * model = params.model.get(); + const auto * vocab = llama_model_get_vocab(model); + + const int seq_id = 0; + + std::vector<llama_logit_bias> logit_bias; + + // Get the token for the piece "World". + const std::string piece = "World"; + std::vector<llama_token> tokens(16); + llama_tokenize(vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false); + + llama_token bias_token = tokens[0]; + // TODO: biasing too much here makes the Vulkan sampling fail - should be investigated further + // https://github.com/ggml-org/llama.cpp/actions/runs/20894267644/job/60030252675?pr=18753#step:3:23350 + //logit_bias.push_back({ bias_token, +100.0f }); + logit_bias.push_back({ bias_token, +10.0f }); + + printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token); + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias( + llama_vocab_n_tokens(vocab), + logit_bias.size(), + logit_bias.data())); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id)); + printf("sampled token = %d, expected = %d\n", backend_token, bias_token); + GGML_ASSERT(backend_token == bias_token); + + printf("backend logit bias sampling test PASSED\n"); +} + +// This test verifies that it is possible to have two different backend sampler, +// one that uses the backend dist sampler, and another that uses CPU dist sampler. +static void test_backend_mixed_sampling(const test_params & params) { + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88)); + + int k = 40; + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + { 1, sampler_chain_1.get() } + }; + + test_context test_ctx(params, backend_sampler_configs); + + std::map<llama_seq_id, std::string> prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfiy sequence 0 that used the dist backend sampler. 
+ { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr); + //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0); + } + + // Verfiy sequence 1 that used the top-k backend sampler. + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(logits != nullptr); + size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == (size_t) k); + GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL); + } + + printf("backend mixed sampling test PASSED\n"); +} + +static void test_backend_set_sampler(const test_params & params) { + const int seq_id = 0; + const int32_t seed = 88; + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using backend sampler configured above + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + + // Now clear the backend sampler for this sequence. + llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr); + printf("Cleared backend sampler for seq_id %d\n", seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, }; + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Should not have any sampled token or probs after clearing the backend sampler. + const int32_t idx = test_ctx.idx_for_seq(seq_id); + GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL); + GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr); + + // Sample the token using the CPU sampler chain. + llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id); + const std::string token2_str = test_ctx.token_to_piece(token2, false); + printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str()); + std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, }; + + // Set a new backend sampler for the sequence. 
+ struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params)); + llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20)); + llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get()); + + if (!test_ctx.decode_tokens(tokens2)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id)); + const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false); + printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); + + printf("backend set sampler test PASSED\n"); +} + +static void test_backend_cpu_mixed_batch(const test_params & params) { + // Sequence 0 uses backend sampling + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88)); + + std::vector<llama_sampler_seq_config> backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + }; + + // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling + test_context test_ctx(params, backend_sampler_configs, 2); + + std::map<llama_seq_id, std::string> prompts = { + {0, "Hello"}, // Will use backend sampling + {1, "Some"} // Will use CPU sampling + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 (backend sampled) + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 (CPU sampled) + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL); + + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy()); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Clear/remove the backend sampler, and sample again + { + // clear the backend sampler for seq 0 so that there are no backend + // samplers. + llama_set_sampler(test_ctx.ctx.get(), 0, nullptr); + + // Create a CPU sampler and verify we can sampler from it. 
+ struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy()); + + int32_t batch_idx = test_ctx.idx_for_seq(1); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + if (!test_ctx.decode_token(token, 1)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + // Set a backend sampler so that we can verify that it can be reset + { + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88)); + + llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get()); + + if (!test_ctx.decode_token(3834, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + printf("backend-cpu mixed batch test PASSED\n"); +} + +static void test_backend_max_outputs(const test_params & params) { + const int seq_id = 0; + const int32_t seed = 88; + + llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + test_context test_ctx(params, backend_sampler_configs); + + llama_batch batch = llama_batch_init(512, 0, 1); + std::string prompt = "Hello"; + + std::vector<llama_token> tokens; + tokens.push_back(llama_vocab_bos(test_ctx.vocab)); + + std::vector<llama_token> prompt_tokens(32); + int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + // set all tokens as output to trigger error + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } + + printf(">>> test_max_outputs expected error start:\n"); + const int ret = llama_decode(test_ctx.ctx.get(), batch); + GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence"); + printf("<<< test_max_outputs expected error end.\n"); + llama_batch_free(batch); + + printf("backend max outputs test PASSED\n"); +} + +struct backend_test_case { + std::string name; + void (*fn)(const test_params &); + bool enabled_by_default; +}; + +static const backend_test_case BACKEND_TESTS[] = { + { "greedy", test_backend_greedy_sampling, true }, + { "logit_bias", test_backend_logit_bias_sampling, true }, + { "temp", test_backend_temp_sampling, true }, + { "temp_ext", test_backend_temp_ext_sampling, true }, + { "top_k", test_backend_top_k_sampling, true }, + { "multi_sequence", test_backend_multi_sequence_sampling, true }, + { "dist", test_backend_dist_sampling, true }, + { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true }, + { "set_sampler", test_backend_set_sampler, true }, + { "max_outputs", 
test_backend_max_outputs, true }, + { "mixed", test_backend_mixed_sampling, true }, + { "min_p", test_backend_min_p_sampling, true }, + { "cpu_mixed", test_backend_cpu_mixed_batch, true }, + { "top_p", test_backend_top_p_sampling, true }, +}; + +static test_args parse_cli(int argc, char ** argv) { + test_args out; + + for (int i = 1; i < argc; ++i) { + const char * arg = argv[i]; + + if (std::strcmp(arg, "--test") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--test expects a value\n"); + exit(EXIT_FAILURE); + } + out.test = argv[++i]; + continue; + } + if (std::strncmp(arg, "--test=", 7) == 0) { + out.test = arg + 7; + continue; + } + if (std::strcmp(arg, "--model") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--model expects a value\n"); + exit(EXIT_FAILURE); + } + out.model = argv[++i]; + continue; + } + if (std::strncmp(arg, "--model=", 8) == 0) { + out.model = arg + 8; + continue; + } + if (std::strcmp(arg, "--device") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--device expects a value (cpu or gpu)\n"); + exit(EXIT_FAILURE); + } + out.device = argv[++i]; + continue; + } + if (std::strncmp(arg, "--device=", 9) == 0) { + out.device = arg + 9; + continue; + } + if (out.model.empty()) { + out.model = arg; + continue; + } + + fprintf(stderr, "Unexpected argument: %s\n", arg); + exit(EXIT_FAILURE); + } + + if (out.device != "cpu" && out.device != "gpu" && out.device != "auto") { + fprintf(stderr, "Invalid device '%s'. Must be 'cpu', 'gpu' or 'auto'\n", out.device.c_str()); + exit(EXIT_FAILURE); + } + + return out; +} + +static std::vector<const backend_test_case *> collect_tests_to_run(const std::string & requested) { + std::vector<const backend_test_case *> selected; + + if (!requested.empty()) { + for (const auto & test : BACKEND_TESTS) { + if (test.name == requested) { + selected.push_back(&test); + break; + } + } + if (selected.empty()) { + fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested.c_str()); + for (const auto & test : BACKEND_TESTS) { + fprintf(stderr, " %s\n", test.name.c_str()); + } + exit(EXIT_FAILURE); + } + } else { + for (const auto & test : BACKEND_TESTS) { + if (test.enabled_by_default) { + selected.push_back(&test); + } + } + } + + if (selected.empty()) { + fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n"); + } + + return selected; +} + +static void run_tests(const std::vector<const backend_test_case *> & tests, const test_params & args) { + for (const auto & test : tests) { + fprintf(stderr, "\n=== %s ===\n", test->name.c_str()); + try { + test->fn(args); + } catch (const std::exception & e) { + fprintf(stderr, "Error running test '%s': %s\n", test->name.c_str(), e.what()); + exit(EXIT_FAILURE); + } + } +} + +int main(int argc, char ** argv) { + test_args args = parse_cli(argc, argv); + + if (args.model.empty()) { + args.model = get_model_or_exit(1, argv); + } + + { + std::ifstream file(args.model); + if (!file.is_open()) { + fprintf(stderr, "no model '%s' found\n", args.model.c_str()); + return EXIT_FAILURE; + } + } + + fprintf(stderr, "using '%s'\n", args.model.c_str()); + + llama_backend_init(); + + test_params params = { + /*.model =*/ load_model(args), + }; + + const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test); + if (!tests.empty()) { + run_tests(tests, params); + } + + return 0; +} |
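
For orientation, the pattern exercised throughout this test file is: build a llama_sampler chain, attach it to a sequence through llama_sampler_seq_config when the context is created, decode a batch, and read the result back with llama_get_sampled_token_ith(). Below is a minimal, self-contained sketch of that flow, distilled from the tests above; the model path, the seed (88), the top-k value (40) and the single-BOS-token prompt are placeholder values, and error handling is omitted.

#include "llama.h"
#include "llama-cpp.h"
#include "common.h"

int main() {
    llama_backend_init();

    // Load the model (placeholder path).
    llama_model_ptr model(llama_model_load_from_file("model.gguf", llama_model_default_params()));
    const llama_vocab * vocab = llama_model_get_vocab(model.get());

    // Backend sampler chain for sequence 0: top-k filtering followed by dist sampling.
    llama_sampler_ptr chain(llama_sampler_chain_init(llama_sampler_chain_default_params()));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    // Attach the chain to sequence 0 via the context parameters.
    llama_sampler_seq_config configs[] = { { /*seq_id =*/ 0, chain.get() } };

    llama_context_params cparams = llama_context_default_params();
    cparams.samplers   = configs;
    cparams.n_samplers = 1;
    cparams.n_seq_max  = 1;

    llama_context_ptr ctx(llama_init_from_model(model.get(), cparams));

    // Decode a single BOS token and mark it as an output so it gets sampled.
    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, llama_vocab_bos(vocab), 0, { 0 }, true);
    llama_decode(ctx.get(), batch);

    // Read back the token sampled on the backend (-1 = last output in the batch).
    llama_token token = llama_get_sampled_token_ith(ctx.get(), -1);

    llama_batch_free(batch);
    return (token >= 0) ? 0 : 1;
}

The test binary itself follows parse_cli() above: it accepts --model=<gguf path> (or a positional model path), an optional --test=<name> to run a single case from BACKEND_TESTS, and --device=cpu|gpu|auto.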
