1#include "arg.h"
   2#include "common.h"
   3#include "log.h"
   4#include "llama.h"
   5
   6#include <chrono>
   7#include <algorithm>
   8#include <array>
   9#include <atomic>
  10#include <cmath>
  11#include <cstdio>
  12#include <cstring>
  13#include <ctime>
  14#include <fstream>
  15#include <mutex>
  16#include <random>
  17#include <sstream>
  18#include <thread>
  19#include <vector>
  20
  21#if defined(_MSC_VER)
  22#pragma warning(disable: 4244 4267) // possible loss of data
  23#endif
  24
  25struct results_perplexity {
  26    std::vector<llama_token> tokens;
  27    double                   ppl_value;
  28    std::vector<float>       logits;
  29    std::vector<float>       probs;
  30};
  31
  32struct results_log_softmax {
  33    double log_softmax;
  34    float  logit;
  35    float  prob;
  36};
  37
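// numerically stable softmax: subtracting the maximum logit before exponentiating
// gives p_i = exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j) without overflow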
  38static std::vector<float> softmax(const std::vector<float>& logits) {
  39    std::vector<float> probs(logits.size());
  40    float max_logit = logits[0];
  41    for (float v : logits) {
  42        max_logit = std::max(max_logit, v);
  43    }
  44    double sum_exp = 0.0;
  45    for (size_t i = 0; i < logits.size(); i++) {
  46        // Subtract the maximum logit value from the current logit value for numerical stability
  47        const float logit = logits[i] - max_logit;
  48        const float exp_logit = expf(logit);
  49        sum_exp += exp_logit;
  50        probs[i] = exp_logit;
  51    }
  52    for (size_t i = 0; i < probs.size(); i++) {
  53        probs[i] /= sum_exp;
  54    }
  55    return probs;
  56}
  57
  58static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
  59    float max_logit = logits[0];
  60    for (int i = 1; i < n_vocab; ++i) {
  61        max_logit = std::max(max_logit, logits[i]);
  62    }
  63    double sum_exp = 0.0;
  64    for (int i = 0; i < n_vocab; ++i) {
  65        sum_exp += expf(logits[i] - max_logit);
  66    }
  67    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
  68}
  69
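// rounds to the nearest integer without calling lrintf: adding 12582912.f (1.5 * 2^23) shifts the
// value into a range where the float mantissa holds the rounded integer, which is then recovered
// from the low 23 mantissa bits (biased by 0x00400000); only valid for |fval| up to ~4.2e6
// (see the commented-out assert)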
  70static inline int nearest_int(float fval) {
  71    //assert(fval <= 4194303.f);
  72    float val = fval + 12582912.f;
  73    int i; memcpy(&i, &val, sizeof(int));
  74    return (i & 0x007fffff) - 0x00400000;
  75}
  76
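// computes the log-softmax for one position and stores the whole distribution in a compact form:
// the first 4 uint16 slots hold two floats - the quantization scale and the minimum log-prob -
// followed by n_vocab quantized values q[i] such that log p(i) ~= scale*q[i] + min_log_prob
// (logits more than 16 below the maximum are clamped); returns the NLL of token `tok`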
  77static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
  78    float max_logit = logits[0];
  79    float min_logit = logits[0];
  80    for (int i = 1; i < n_vocab; ++i) {
  81        max_logit = std::max(max_logit, logits[i]);
  82        min_logit = std::min(min_logit, logits[i]);
  83    }
  84    min_logit = std::max(min_logit, max_logit - 16);
  85    double sum_exp = 0.0;
  86    for (int i = 0; i < n_vocab; ++i) {
  87        sum_exp += expf(logits[i] - max_logit);
  88    }
  89    const float log_sum_exp = log(sum_exp);
  90    const float min_log_prob = min_logit - max_logit - log_sum_exp;
  91    const float scale = (max_logit - min_logit)/65535.f;
  92    float * d = (float *)log_prob;
  93    d[0] = scale;
  94    d[1] = min_log_prob;
  95    log_prob += 4;
  96    if (scale) {
  97        const float inv_scale = 1/scale;
  98        for (int i = 0; i < n_vocab; ++i) {
  99            log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
 100        }
 101    } else {
 102        std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
 103    }
 104    return max_logit + log_sum_exp - logits[tok];
 105}
 106
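// computes the NLL of every evaluated token in parallel: workers pull token indices from a shared
// counter, accumulate nll/nll^2 into thread-local sums and fold them into the global sums under
// the mutex once the counter runs out; also records the per-token logit and probability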
 107static void process_logits(
 108    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
 109    double & nll, double & nll2, float * logit_history, float * prob_history
 110) {
 111    std::mutex mutex;
 112    int counter = 0;
 113    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
 114        double local_nll  = 0;
 115        double local_nll2 = 0;
 116        while (true) {
 117            std::unique_lock<std::mutex> lock(mutex);
 118            int i = counter++;
 119            if (i >= n_token) {
 120                nll += local_nll; nll2 += local_nll2;
 121                break;
 122            }
 123            lock.unlock();
 124            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
 125            const double v = -results.log_softmax;
 126            local_nll += v;
 127            local_nll2 += v*v;
 128
 129            logit_history[i] = results.logit;
 130            prob_history[i]  = results.prob;
 131        }
 132    };
 133    for (auto & w : workers) {
 134        w = std::thread(compute);
 135    }
 136    compute();
 137    for (auto & w : workers) {
 138        w.join();
 139    }
 140}
 141
 142static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
 143        std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
 144    std::mutex mutex;
 145    const int nv = 2*((n_vocab + 1)/2) + 4;
 146    int counter = 0;
 147    auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
 148        double local_nll  = 0;
 149        double local_nll2 = 0;
 150        while (true) {
 151            std::unique_lock<std::mutex> lock(mutex);
 152            int i = counter++;
 153            if (i >= n_token) {
 154                nll += local_nll; nll2 += local_nll2;
 155                break;
 156            }
 157            lock.unlock();
 158            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
 159            local_nll += v;
 160            local_nll2 += v*v;
 161        }
 162    };
 163    for (auto & w : workers) {
 164        w = std::thread(compute);
 165    }
 166    compute();
 167    for (auto & w : workers) {
 168        w.join();
 169    }
 170    out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
 171}
 172
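// running sums accumulated while comparing the current model against saved base logits;
// means and uncertainties can then be derived from these sums, e.g. mean KLD = sum_kld/count and
// var(KLD) = sum_kld2/count - (sum_kld/count)^2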
 173struct kl_divergence_result {
 174    double sum_nll          = 0.0;
 175    double sum_nll2         = 0.0;
 176    double sum_nll_base     = 0.0;
 177    double sum_nll_base2    = 0.0;
 178    double sum_nll_nll_base = 0.0;
 179    double sum_kld          = 0.0;
 180    double sum_kld2         = 0.0;
 181    double sum_p_diff       = 0.0;
 182    double sum_p_diff2      = 0.0;
 183    double sum_p_diff4      = 0.0;
 184    float  max_p_diff       = 0.0f;
    size_t n_same_top       = 0;
    size_t count            = 0;
 187};
 188
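// scores one position against the quantized base log-probs: accumulates the NLL of both models,
// their cross term, KL(base || current) = sum_i p_base(i)*(log p_base(i) - log p(i)) (terms with
// log p_base <= -16 are skipped), whether both models agree on the top token, and the difference
// of the probabilities assigned to the correct token; returns {KLD, p - p_base}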
 189static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
 190    float max_logit = logits[0];
 191    int imax = 0;
 192    for (int i = 1; i < n_vocab; ++i) {
 193        if (logits[i] > max_logit) {
 194            max_logit = logits[i];
 195            imax = i;
 196        }
 197    }
 198    double sum_exp = 0.0;
 199    for (int i = 0; i < n_vocab; ++i) {
 200        sum_exp += expf(logits[i] - max_logit);
 201    }
 202    const float log_sum_exp = log(sum_exp);
 203    const float * d = (const float *)base_log_prob;
 204    const float scale = d[0];
 205    const float min_log_prob = d[1];
 206    base_log_prob += 4;
 207
 208    const float nll = max_logit + log_sum_exp - logits[tok];
 209    kld.sum_nll  += nll;
 210    kld.sum_nll2 += nll*nll;
 211
 212    const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
 213    kld.sum_nll_base  += nll_base;
 214    kld.sum_nll_base2 += nll_base*nll_base;
 215
 216    kld.sum_nll_nll_base += nll*nll_base;
 217
 218    max_logit += log_sum_exp;
 219    double sum = 0;
 220    int imax_base = -1;
 221    float p_log_base_max = 0;
 222    for (int i = 0; i < n_vocab; ++i) {
 223        const float p_log_base = scale*base_log_prob[i] + min_log_prob;
 224        if (i == 0 || p_log_base > p_log_base_max) {
 225            p_log_base_max = p_log_base;
 226            imax_base = i;
 227        }
 228        if (p_log_base > -16.f) {
 229            const float p_base = expf(p_log_base);
 230            sum += p_base * (p_log_base - logits[i] + max_logit);
 231        }
 232    }
 233    kld.sum_kld  += sum;
 234    kld.sum_kld2 += sum*sum;
 235    ++kld.count;
 236    if (imax == imax_base) {
 237        ++kld.n_same_top;
 238    }
 239
 240    const float p_base = expf(-nll_base);
 241    const float p = expf(-nll);
 242    const float p_diff = p - p_base;
 243    kld.sum_p_diff  += p_diff;
 244    const double p_diff2 = p_diff*p_diff;
 245    kld.sum_p_diff2 += p_diff2;
 246    kld.sum_p_diff4 += p_diff2*p_diff2;
 247    kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));
 248
 249    return std::make_pair(sum, p_diff);
 250}
 251
 252static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
 253        std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
 254        float * kld_values, float * p_diff_values) {
 255    std::mutex mutex;
 256    const int nv = 2*((n_vocab + 1)/2) + 4;
 257    int counter = 0;
 258    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
 259        kl_divergence_result local_kld;
 260        while (true) {
 261            std::unique_lock<std::mutex> lock(mutex);
 262            int i = counter++;
 263            if (i >= n_token) {
 264                kld.sum_nll          += local_kld.sum_nll;
 265                kld.sum_nll2         += local_kld.sum_nll2;
 266                kld.sum_nll_base     += local_kld.sum_nll_base;
 267                kld.sum_nll_base2    += local_kld.sum_nll_base2;
 268                kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
 269                kld.sum_kld          += local_kld.sum_kld;
 270                kld.sum_kld2         += local_kld.sum_kld2;
 271                kld.sum_p_diff       += local_kld.sum_p_diff;
 272                kld.sum_p_diff2      += local_kld.sum_p_diff2;
 273                kld.sum_p_diff4      += local_kld.sum_p_diff4;
 274                kld.n_same_top       += local_kld.n_same_top;
 275                kld.max_p_diff        = std::max(kld.max_p_diff, local_kld.max_p_diff);
 276                kld.count            += local_kld.count;
 277                break;
 278            }
 279            lock.unlock();
 280            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
 281            kld_values[i]    = (float)v.first;
 282            p_diff_values[i] = v.second;
 283        }
 284    };
 285    for (auto & w : workers) {
 286        w = std::thread(compute);
 287    }
 288    compute();
 289    for (auto & w : workers) {
 290        w.join();
 291    }
 292}
 293
 294static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) {
 295    // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
    // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
 297    // Output: `perplexity: 13.5106 [114/114]`
 298    // BOS tokens will be added for each chunk before eval
 299
 300    const llama_model * model = llama_get_model(ctx);
 301    const llama_vocab * vocab = llama_model_get_vocab(model);
 302
 303    const bool add_bos = llama_vocab_get_add_bos(vocab);
 304    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
 305
 306    LOG_INF("%s: tokenizing the input ..\n", __func__);
 307
 308    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
 309
 310    const int n_ctx = llama_n_ctx(ctx);
 311
 312    if (int(tokens.size()) < 2*n_ctx) {
 313        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
 314                n_ctx);
 315        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
 316        return {std::move(tokens), 0., {}, {}};
 317    }
 318
 319    std::vector<float> logit_history;
 320    std::vector<float> prob_history;
 321
 322    logit_history.resize(tokens.size());
 323    prob_history.resize(tokens.size());
 324
 325    if (params.ppl_stride <= 0) {
 326        LOG_ERR("%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
 327        return {tokens, -1, logit_history, prob_history};
 328    }
 329
 330    const int calc_chunk = n_ctx;
 331
 332    LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
 333
 334    if (int(tokens.size()) <= calc_chunk) {
 335        LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
 336                tokens.size(), n_ctx, params.ppl_stride);
 337        return {tokens, -1, logit_history, prob_history};
 338    }
 339
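    // with a stride of ppl_stride tokens, chunk i covers tokens [i*stride, i*stride + n_ctx),
    // so the number of evaluatable chunks is ceil((n_tokens - n_ctx) / stride)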
 340    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1)  / params.ppl_stride;
 341
 342    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
 343    const int n_batch = params.n_batch;
 344
 345    const int n_vocab = llama_vocab_n_tokens(vocab);
 346
 347    int count = 0;
 348    double nll = 0.0;
 349
 350    LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 351
 352    for (int i = 0; i < n_chunk; ++i) {
 353        const int start =     i * params.ppl_stride;
 354        const int end   = start + calc_chunk;
 355
 356        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
 357        //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
 358
 359        std::vector<float> logits;
 360
 361        const auto t_start = std::chrono::high_resolution_clock::now();
 362
 363        // clear the KV cache
 364        llama_memory_clear(llama_get_memory(ctx), true);
 365
 366        llama_batch batch = llama_batch_init(n_batch, 0, 1);
 367
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }

            common_batch_clear(batch);
            for (int i = 0; i < batch_size; i++) {
                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
            }

            //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
            if (llama_decode(ctx, batch)) {
                //LOG_ERR("%s : failed to eval\n", __func__);
                llama_batch_free(batch);
                return {tokens, -1, logit_history, prob_history};
            }

            // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;

            const auto * batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
        }
 399
 400        llama_batch_free(batch);
 401
 402        const auto t_end = std::chrono::high_resolution_clock::now();
 403
 404        if (i == 0) {
 405            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
 406            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
 407            int total_seconds = (int)(t_total * n_chunk);
 408            if (total_seconds >= 60*60) {
 409                LOG("%d hours ", total_seconds / (60*60));
 410                total_seconds = total_seconds % (60*60);
 411            }
 412            LOG("%.2f minutes\n", total_seconds / 60.0);
 413        }
 414
 415        //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
 416        for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
 417            // Calculate probability of next token, given the previous ones.
 418            const std::vector<float> tok_logits(
 419                logits.begin() + size_t(j + 0) * n_vocab,
 420                logits.begin() + size_t(j + 1) * n_vocab);
 421
 422            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
 423            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
 424            prob_history[start + j + 1]  = prob;
 425
 426            nll += -std::log(prob);
 427            ++count;
 428        }
 429        // perplexity is e^(average negative log-likelihood)
 430        if (params.ppl_output_type == 0) {
 431            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
 432        } else {
 433            LOG("%8d  %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
 434        }
 435    }
 436    LOG("\n");
 437
 438    return {tokens, std::exp(nll / count), logit_history, prob_history};
 439}
 440
 441static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
 442    if (params.ppl_stride > 0) {
 443        return perplexity_v2(ctx, params);
 444    }
 445
 446    // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
 447    // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
 448    // Output: `perplexity: 13.5106 [114/114]`
 449    // BOS tokens will be added for each chunk before eval
 450
 451    const llama_model * model = llama_get_model(ctx);
 452    const llama_vocab * vocab = llama_model_get_vocab(model);
 453
 454    const bool add_bos = llama_vocab_get_add_bos(vocab);
 455    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
 456
 457    std::ofstream logits_stream;
 458    if (!params.logits_file.empty()) {
 459        logits_stream.open(params.logits_file.c_str(), std::ios::binary);
 460        if (!logits_stream.is_open()) {
 461            LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
 462            return {};
 463        }
 464        LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
 465        logits_stream.write("_logits_", 8);
 466        logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
 467    }
 468
 469    auto tim1 = std::chrono::high_resolution_clock::now();
 470    LOG_INF("%s: tokenizing the input ..\n", __func__);
 471
 472    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
 473
 474    auto tim2 = std::chrono::high_resolution_clock::now();
 475    LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
 476
 477    if (int(tokens.size()) < 2*n_ctx) {
 478        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
 479                n_ctx);
 480        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
 481        return {std::move(tokens), 0., {}, {}};
 482    }
 483
 484    std::vector<float> logit_history;
 485    logit_history.resize(tokens.size());
 486
 487    std::vector<float> prob_history;
 488    prob_history.resize(tokens.size());
 489
 490    const int n_chunk_max = tokens.size() / n_ctx;
 491
 492    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
 493    const int n_batch = params.n_batch;
 494
 495    const int n_vocab = llama_vocab_n_tokens(vocab);
 496
 497    int count = 0;
 498    double nll = 0.0;
 499    double nll2 = 0.0;
 500
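    // when n_batch is a multiple of n_ctx, n_seq = n_batch/n_ctx chunks are evaluated in parallel
    // as independent sequences within a single llama_batch, each with its own sequence id and its
    // own range of n_ctx positions; otherwise a chunk is split across num_batches decode calls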
 501    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 502    const int n_seq = std::max(1, n_batch / n_ctx);
 503
 504    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
 505    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 506
 507    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 508
 509    std::vector<float> logits;
 510    if (num_batches > 1) {
 511        logits.reserve(size_t(n_ctx) * n_vocab);
 512    }
 513
 514    LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 515
 516    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 517
 518    std::vector<uint16_t> log_probs;
 519    if (!params.logits_file.empty()) {
 520        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
 521        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
 522        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
 523        const int nv = 2*((n_vocab + 1)/2) + 4;
 524        log_probs.resize(n_ctx * nv);
 525    }
 526
 527    // We get the logits for all the tokens in the context window (params.n_ctx)
 528    // from llama_decode below.  Now, based on https://huggingface.co/docs/transformers/perplexity,
 529    // calculate the perplexity over the last half of the window (so the model always has
 530    // some context to predict the token).
 531    //
 532    // We rely on the fact that attention in the forward pass only looks at previous
 533    // tokens here, so the logits returned for each token are an accurate representation
 534    // of what the model would have predicted at that point.
 535    //
    // For example, with a context window of 512 we compute perplexity for each of the
    // last 256 tokens. Then, we split the input up into context-window-sized chunks to
    // process the entire prompt.
 539    const int first = n_ctx/2;
 540
 541    for (int i = 0; i < n_chunk; i += n_seq) {
 542        const int start =     i * n_ctx;
 543        const int end   = start + n_ctx;
 544
 545        const int n_seq_batch = std::min(n_seq, n_chunk - i);
 546
 547        const auto t_start = std::chrono::high_resolution_clock::now();
 548
 549        // clear the KV cache
 550        llama_memory_clear(llama_get_memory(ctx), true);
 551
 552        for (int j = 0; j < num_batches; ++j) {
 553            const int batch_start = start + j * n_batch;
 554            const int batch_size  = std::min(end - batch_start, n_batch);
 555
 556            int n_outputs = 0;
 557
 558            batch.n_tokens = 0;
 559            for (int seq = 0; seq < n_seq_batch; seq++) {
 560                int seq_start = batch_start + seq*n_ctx;
 561
 562                // save original token and restore it after decode
 563                const auto token_org = tokens[seq_start];
 564
 565                // add BOS token for the first batch of each chunk
 566                if (add_bos && j == 0) {
 567                    tokens[seq_start] = llama_vocab_bos(vocab);
 568                }
 569
 570                for (int k = 0; k < batch_size; ++k) {
 571                    const int idx = seq*n_ctx + k;
 572                    batch.token   [idx]    = tokens[seq_start + k];
 573                    batch.pos     [idx]    = j*n_batch + k;
 574                    batch.n_seq_id[idx]    = 1;
 575                    batch.seq_id  [idx][0] = seq;
 576                    batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
 577
 578                    n_outputs += batch.logits[idx] != 0;
 579                }
 580                batch.n_tokens += batch_size;
 581
 582                // restore the original token in case it was set to BOS
 583                tokens[seq_start] = token_org;
 584            }
 585
 586            if (llama_decode(ctx, batch)) {
 587                LOG_INF("%s : failed to decode\n", __func__);
 588                return {tokens, -1, logit_history, prob_history};
 589            }
 590
 591            if (num_batches > 1 && n_outputs > 0) {
 592                const auto * batch_logits = llama_get_logits(ctx);
 593                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
 594            }
 595        }
 596
 597
 598        if (i == 0) {
 599            llama_synchronize(ctx);
 600            const auto t_end = std::chrono::high_resolution_clock::now();
 601            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
 602            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
 603            int total_seconds = (int)(t_total*n_chunk/n_seq);
 604            if (total_seconds >= 60*60) {
 605                LOG("%d hours ", total_seconds / (60*60));
 606                total_seconds = total_seconds % (60*60);
 607            }
 608            LOG("%.2f minutes\n", total_seconds / 60.0);
 609        }
 610
 611        for (int seq = 0; seq < n_seq_batch; seq++) {
 612            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
 613
 614            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
 615            if (!params.logits_file.empty()) {
 616                process_logits(logits_stream, n_vocab, all_logits,
 617                        tokens_data, n_ctx - 1 - first,
 618                        workers, log_probs, nll, nll2);
 619            } else {
 620                process_logits(n_vocab, all_logits,
 621                        tokens_data, n_ctx - 1 - first,
 622                        workers, nll, nll2,
 623                        logit_history.data() + start + seq*n_ctx + first,
 624                        prob_history.data()  + start + seq*n_ctx + first);
 625            }
 626            count += n_ctx - first - 1;
 627
 628            // perplexity is e^(average negative log-likelihood)
 629            if (params.ppl_output_type == 0) {
 630                LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
 631            } else {
 632                double av = nll/count;
 633                double av2 = nll2/count - av*av;
 634                if (av2 > 0) {
 635                    av2 = sqrt(av2/(count-1));
 636                }
 637                LOG("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
 638            }
 639        }
 640
 641        logits.clear();
 642    }
 643    LOG("\n");
 644
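    // nll becomes the mean NLL and nll2 its variance; the standard error of the mean is
    // sqrt(var/(count-1)), and the PPL uncertainty follows from propagating through exp():
    // d(ppl) = ppl * d(nll)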
 645    nll2 /= count;
 646    nll /= count;
 647    const double ppl = exp(nll);
 648    nll2 -= nll * nll;
 649    if (nll2 > 0) {
 650        nll2 = sqrt(nll2/(count-1));
 651        LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
 652    } else {
 653        LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
 654    }
 655
 656    llama_batch_free(batch);
 657
 658    return {tokens, ppl, logit_history, prob_history};
 659}
 660
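// decodes a logical batch by splitting it into views of at most n_batch tokens; only positions
// with logits requested produce output rows, which are packed back to back into batch_logits,
// so callers index them by output order rather than by token position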
 661static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
 662    int prev_outputs = 0;
 663    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
 664        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
 665
 666        llama_batch batch_view = {
 667            n_tokens,
 668            batch.token    + i,
 669            nullptr,
 670            batch.pos      + i,
 671            batch.n_seq_id + i,
 672            batch.seq_id   + i,
 673            batch.logits   + i,
 674        };
 675
 676        const int ret = llama_decode(ctx, batch_view);
 677        if (ret != 0) {
 678            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
 679            return false;
 680        }
 681
 682        int n_outputs = 0;
 683        for (int i = 0; i < n_tokens; ++i) {
 684            n_outputs += batch_view.logits[i] != 0;
 685        }
 686
 687        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
 688
 689        prev_outputs += n_outputs;
 690    }
 691
 692    return true;
 693}
 694
 695#define K_TOKEN_CHUNK 4
 696
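// for each (output row, token) pair in eval_pairs computes
//   log p(token) = logit[token] - max_logit - log(sum_j exp(logit[j] - max_logit))
// handing out work to the worker threads in chunks of K_TOKEN_CHUNK pairs via an atomic counter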
 697static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
 698        const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
 699    if (eval_results.size() != eval_pairs.size()) {
 700        eval_results.resize(eval_pairs.size());
 701    }
 702    if (eval_pairs.empty()) {
 703        return;
 704    }
 705
 706    size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
 707
 708    std::atomic<int> counter(0);
 709    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
 710        float local_logprobs[K_TOKEN_CHUNK];
 711        while (true) {
 712            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
 713            if (first >= eval_results.size()) {
 714                break;
 715            }
 716            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
 717            for (size_t i = first; i < last; ++i) {
 718                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
 719                float max_logit = logits[0];
 720                for (int j = 1; j < n_vocab; ++j) {
 721                    max_logit = std::max(max_logit, logits[j]);
 722                }
 723                float sum_p = 0.f;
 724                for (int j = 0; j < n_vocab; ++j) {
 725                    sum_p += expf(logits[j] - max_logit);
 726                }
 727                local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
 728            }
 729            std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
 730        }
 731    };
 732
 733    for (size_t it = 0; it < max_threads; ++it) {
 734        workers[it] = std::thread(compute);
 735    }
 736    for (size_t it = 0; it < max_threads; ++it) {
 737        workers[it].join();
 738    }
 739}
 740
 741static void hellaswag_score(llama_context * ctx, const common_params & params) {
 742    const llama_model * model = llama_get_model(ctx);
 743    const llama_vocab * vocab = llama_model_get_vocab(model);
 744
 745    // Calculates hellaswag score (acc_norm) from prompt
 746    //
 747    // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
 748    // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
 749    //
 750    // All 10042 tasks should be extracted to keep the results standardized like other implementations.
 751    //
 752    // Datafile layout:
 753    // ['??'] denotes json fields
 754    // 6 lines per task:
 755    // ['activity_label'] + ": " +['ctx']  - The first part of the query, the context
 756    // ['label'] - The index the best common sense ending aka gold ending
 757    // ['endings'][0] - Endings added to the first part of the query
 758    // ['endings'][1]
 759    // ['endings'][2]
 760    // ['endings'][3]
 761
 762    std::vector<std::string> prompt_lines;
 763    std::istringstream strstream(params.prompt);
 764    std::string line;
 765
 766    while (std::getline(strstream,line,'\n')) {
 767        prompt_lines.push_back(line);
 768    }
 769
 770    if (prompt_lines.size() % 6 != 0) {
 771        LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__);
 772        return;
 773    }
 774
 775    size_t hs_task_count = prompt_lines.size()/6;
 776    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 777
 778    const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
 779    LOG_INF("================================= is_spm = %d\n", is_spm);
 780
 781    // The tasks should be randomized so the score stabilizes quickly.
 782    bool randomize_tasks = true;
 783
 784    // Number of tasks to use when computing the score
 785    if (params.hellaswag_tasks < hs_task_count) {
 786        hs_task_count = params.hellaswag_tasks;
 787    }
 788
    // The random seed should not impact the final result if the computation is done over enough tasks, so it is kept hardcoded for now
 790    std::mt19937 rng(1);
 791
 792    // Dataholder for hellaswag tasks
 793    struct hs_data_t {
 794        std::string context;
 795        size_t gold_ending_idx;
 796        std::string ending[4];
 797        size_t ending_logprob_count[4];
 798        double ending_logprob[4];
 799
 800        size_t i_logits;        // starting index of logits in the llama_batch
 801        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
 802        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
 803        std::vector<llama_token> seq_tokens[4];
 804    };
 805
 806    LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first")  );
 807
 808    // Select and read data from prompt lines
 809    std::vector<hs_data_t> hs_data(hs_task_count);
 810    for (size_t i = 0; i < hs_task_count; i++) {
 811        size_t idx = i;
 812
 813        auto & hs_cur = hs_data[i];
 814
 815        // Select a random example of those left in the prompt
 816        if (randomize_tasks) {
 817            std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
 818            idx = dist(rng);
 819        }
 820
 821        hs_cur.context = prompt_lines[idx*6];
 822        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
 823        for (size_t j = 0; j < 4; j++) {
 824            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
 825            hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
 826        }
 827
 828        // determine the common prefix of the endings
 829        hs_cur.common_prefix = 0;
 830        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
 831            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
 832                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
 833                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
 834                break;
 835            }
 836            hs_cur.common_prefix++;
 837        }
 838        hs_cur.required_tokens = hs_cur.common_prefix +
 839            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
 840            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
 841            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
 842            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
 843
 844        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
 845
 846        // Delete the selected random example from the prompt
 847        if (randomize_tasks) {
 848            prompt_lines.erase( std::next(prompt_lines.begin(),idx*6)  , std::next(prompt_lines.begin(),idx*6+6) );
 849        }
 850    }
 851
 852    LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
 853
 854    LOG("\ntask\tacc_norm\t95%% confidence interval\n");
 855
 856    double acc = 0.0f;
 857
 858    const int n_ctx   = llama_n_ctx(ctx);
 859    const int n_batch = params.n_batch;
 860
 861    const int n_vocab = llama_vocab_n_tokens(vocab);
 862
 863    const int max_tasks_per_batch = 32;
 864    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 865
 866    llama_batch batch = llama_batch_init(n_ctx, 0, 4);
 867
 868    std::vector<float> tok_logits(n_vocab);
 869    // TODO: this could be made smaller; it's currently the worst-case size
 870    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
 871
 872    std::vector<std::pair<size_t, llama_token>> eval_pairs;
 873    std::vector<float> eval_results;
 874    std::vector<std::thread> workers(std::thread::hardware_concurrency());
 875
 876    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
 877        int n_cur = 0;
 878
 879        size_t i1 = i0;
 880        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 881
 882        common_batch_clear(batch);
 883
        // batch as many tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
        // the common prefix is shared among the 4 sequences to save tokens
        // we extract logits only from the last token of the common prefix and from all ending tokens except the last one of each sequence
 888        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
 889            auto & hs_cur = hs_data[i1];
 890            int n_logits = 0;
 891
 892            const int s0 = 4*(i1 - i0);
 893            if (s0 + 4 > max_seq) {
 894                break;
 895            }
 896
 897            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
 898                common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
 899            }
 900            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
 901            n_logits += 1;
 902
 903            for (int s = 0; s < 4; ++s) {
 904                const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
 905                // TODO: don't evaluate the last token of each sequence
 906                for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
 907                    const bool needs_logits = i < seq_tokens_size - 1;
 908                    common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
 909                    n_logits += needs_logits;
 910                }
 911            }
 912
 913            hs_cur.i_logits = i_logits;
 914            i_logits += n_logits;
 915
 916            n_cur += hs_data[i1].required_tokens;
 917            if (++i1 == hs_task_count) {
 918                break;
 919            }
 920        }
 921
 922        if (i0 == i1) {
 923            LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, hs_data[i0].required_tokens);
 924            return;
 925        }
 926
 927        llama_memory_clear(llama_get_memory(ctx), true);
 928
 929        // decode all tasks [i0, i1)
 930        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
 931            LOG_ERR("%s: llama_decode() failed\n", __func__);
 932            return;
 933        }
 934
 935        // Compute log-probs in parallel
 936        // First we collect all tasks
 937        eval_pairs.clear();
 938        for (size_t i = i0; i < i1; ++i) {
 939            auto & hs_cur = hs_data[i];
 940            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
 941            for (int s = 0; s < 4; ++s) {
 942                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
 943                    eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
 944                }
 945            }
 946        }
 947        // Then we do the actual calculation
 948        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
 949
 950        size_t ir = 0;
 951
 952        // compute the logprobs for each ending of the decoded tasks
 953        for (size_t i = i0; i < i1; ++i) {
 954            auto & hs_cur = hs_data[i];
 955
 956            // get the logits of the last token of the common prefix
 957            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));
 958
 959            const auto first_probs = softmax(tok_logits);
 960
 961            for (int s = 0; s < 4; ++s) {
 962                hs_cur.ending_logprob_count[s] = 1;
 963                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
 964                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
 965                    hs_cur.ending_logprob[s] += eval_results[ir++];
 966                    hs_cur.ending_logprob_count[s]++;
 967                }
 968                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
 969            }
 970
 971            // Find the ending with maximum logprob
 972            size_t ending_logprob_max_idx = 0;
 973            double ending_logprob_max_val = hs_cur.ending_logprob[0];
 974            for (size_t s = 1; s < 4; s++) {
 975                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
 976                    ending_logprob_max_idx = s;
 977                    ending_logprob_max_val =  hs_cur.ending_logprob[s];
 978                }
 979            }
 980
 981            //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
 982
            // If the gold ending got the maximum logprob, add one accuracy point
 984            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
 985                acc += 1.0;
 986            }
 987
 988            double freq = acc / double(i + 1);
 989
 990            const double za = 1.95996398454;
 991
 992            // // Wald normal approx
 993            // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
 994            // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
 995
 996            // Wilson score interval, more accurate
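            //   center     = (p + z^2/(2n)) / (1 + z^2/n)
            //   half-width = z * sqrt(p*(1-p)/n + z^2/(4*n^2)) / (1 + z^2/n)
            // with z = 1.96 for a 95% confidence level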
 997            double z   = za * za / double(i + 1);
 998            double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
 999            double a   = (freq + z * 0.5 - cnf) / (1.0 + z);
1000            double b   = (freq + z * 0.5 + cnf) / (1.0 + z);
1001
1002            // Print the accumulated accuracy mean x 100 and confidence interval
1003            LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
1004        }
1005
1006        i0 = i1 - 1;
1007    }
1008
1009    llama_batch_free(batch);
1010
1011    LOG("\n");
1012}
1013
1014struct winogrande_entry {
1015    std::string first;
1016    std::string second;
1017    std::array<std::string, 2> choices;
1018    int answer;
1019
1020    size_t i_logits;
1021    size_t common_prefix;
1022    size_t required_tokens;
1023    size_t n_base1; // number of tokens for context + choice 1
1024    size_t n_base2; // number of tokens for context + choice 2
1025    std::vector<llama_token> seq_tokens[2];
1026};
1027
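// parses the Winogrande CSV: each line is  index,sentence,choice1,choice2,answer
// only the first four commas outside of double quotes are treated as separators, so sentences
// containing commas must be quoted; lines that cannot be parsed are skipped with an error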
1028static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
1029    std::vector<winogrande_entry> result;
1030    std::istringstream in(prompt);
1031    std::string line;
1032    std::array<int, 4> comma_pos;
1033    while (true) {
1034        std::getline(in, line);
1035        if (in.fail() || in.eof()) break;
1036        int ipos = 0;
1037        bool quote_open = false;
1038        for (int i = 0; i < int(line.size()); ++i) {
1039            if (!quote_open) {
1040                if (line[i] == ',') {
1041                    comma_pos[ipos++] = i;
1042                    if (ipos == 4) break;
1043                }
1044                else if (line[i] == '"') {
1045                    quote_open = true;
1046                }
1047            }
1048            else {
1049                if (line[i] == '"') {
1050                    quote_open = false;
1051                }
1052            }
1053        }
1054        if (ipos != 4) {
1055            LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
1056            continue;
1057        }
1058        auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
1059                                                    : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
1060        auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
1061        auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
1062        auto answer  = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
1063        auto index = line.substr(0, comma_pos[0]);
1064        int where = 0;
1065        for ( ; where < int(sentence.size()); ++where) {
1066            if (sentence[where] == '_') break;
1067        }
1068        if (where == int(sentence.size())) {
1069            LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str());
1070            continue;
1071        }
1072        std::istringstream stream(answer.c_str());
1073        int i_answer; stream >> i_answer;
1074        if (stream.fail() || i_answer < 1 || i_answer > 2) {
1075            LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
1076            continue;
1077        }
1078        result.emplace_back();
1079        auto& wg = result.back();
1080        wg.first = sentence.substr(0, where);
1081        wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
1082        wg.choices[0] = std::move(choice1);
1083        wg.choices[1] = std::move(choice2);
1084        wg.answer = i_answer;
1085    }
1086    return result;
1087}
1088
1089/*
1090 * Evaluates the Winogrande score.
 * Uses a CSV containing task index, sentence, choice 1, choice 2, answer (1 or 2)
1092 * You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
1093 * As an example, the 1st row in the above dataset is
1094 *
1095 *    0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
1096 *
1097 */
1098static void winogrande_score(llama_context * ctx, const common_params & params) {
1099    const llama_model * model = llama_get_model(ctx);
1100    const llama_vocab * vocab = llama_model_get_vocab(model);
1101
1102    constexpr int k_min_trailing_ctx = 3;
1103
1104    auto data = load_winogrande_from_csv(params.prompt);
1105    if (data.empty()) {
1106        LOG_ERR("%s: no tasks\n", __func__);
1107        return;
1108    }
1109
1110    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size());
1111
1112    if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
1113        LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
1114        std::mt19937 rng(1);
1115        std::vector<int> aux(data.size());
1116        for (int i = 0; i < int(data.size()); ++i) {
1117            aux[i] = i;
1118        }
1119        float scale = 1/(1.f + (float)rng.max());
1120        std::vector<winogrande_entry> selected;
1121        selected.resize(params.winogrande_tasks);
1122        for (int i = 0; i < int(params.winogrande_tasks); ++i) {
1123            int j = int(scale*rng()*aux.size());
1124            selected[i] = std::move(data[aux[j]]);
1125            aux[j] = aux.back();
1126            aux.pop_back();
1127        }
1128        data = std::move(selected);
1129    }
1130
1131    LOG_INF("%s : tokenizing selected tasks\n", __func__);
1132
1133    for (auto & task : data) {
1134        task.seq_tokens[0] = common_tokenize(ctx, task.first + task.choices[0] + task.second, true);
1135        task.seq_tokens[1] = common_tokenize(ctx, task.first + task.choices[1] + task.second, true);
1136
1137        task.common_prefix = 0;
1138        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
1139            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
1140                break;
1141            }
1142            task.common_prefix++;
1143        }
1144
        // TODO: the last token of each sequence does not need to be evaluated
1146        task.required_tokens = task.common_prefix +
1147            task.seq_tokens[0].size() - task.common_prefix +
1148            task.seq_tokens[1].size() - task.common_prefix;
1149
1150        task.n_base1 = common_tokenize(ctx, task.first + task.choices[0], true).size();
1151        task.n_base2 = common_tokenize(ctx, task.first + task.choices[1], true).size();
1152    }
1153
1154    LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
1155
1156    const int n_ctx   = llama_n_ctx(ctx);
1157    const int n_batch = params.n_batch;
1158
1159    const int n_vocab = llama_vocab_n_tokens(vocab);
1160
1161    const int max_tasks_per_batch = 128;
1162    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
1163
1164    llama_batch batch = llama_batch_init(n_ctx, 0, 2);
1165
1166    std::vector<float> tok_logits(n_vocab);
1167    // TODO: this could be made smaller; it's currently the worst-case size
1168    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
1169
1170    std::vector<std::pair<size_t, llama_token>> eval_pairs;
1171    std::vector<float> eval_results;
1172    std::vector<std::thread> workers(std::thread::hardware_concurrency());
1173
1174    int n_correct = 0;
1175    int n_done    = 0;
1176
1177    for (size_t i0 = 0; i0 < data.size(); i0++) {
1178        int n_cur = 0;
1179
1180        size_t i1 = i0;
1181        size_t i_logits = 0;
1182
1183        common_batch_clear(batch);
1184
1185        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
1186            int n_logits = 0;
1187            const int s0 = 2*(i1 - i0);
1188            if (s0 + 2 > max_seq) {
1189                break;
1190            }
1191
1192            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
1193                common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
1194            }
1195            batch.logits[batch.n_tokens - 1] = true;
1196            n_logits += 1;
1197
1198            for (int s = 0; s < 2; ++s) {
1199                // TODO: end before the last token, no need to predict past the end of the sequences
1200                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
1201                    common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
1202                    n_logits += 1;
1203                }
1204            }
1205
1206            data[i1].i_logits = i_logits;
1207            i_logits += n_logits;
1208
1209            n_cur += data[i1].required_tokens;
1210            if (++i1 == data.size()) {
1211                break;
1212            }
1213        }
1214
1215        if (i0 == i1) {
1216            LOG_ERR("%s : task %zu does not fit in the context window (requires %lu tokens)\n", __func__, i0, data[i0].required_tokens);
1217            return;
1218        }
1219
1220        llama_memory_clear(llama_get_memory(ctx), true);
1221
1222        // decode all tasks [i0, i1)
1223        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
1224            LOG_ERR("%s: llama_decode() failed\n", __func__);
1225            return;
1226        }
1227
1228        eval_pairs.clear();
1229        for (size_t i = i0; i < i1; ++i) {
1230            auto & task = data[i];
1231
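            // if both endings have more than k_min_trailing_ctx tokens after the common prefix,
            // score only the tokens that follow the choice word (starting at n_base1/n_base2);
            // otherwise fall back to scoring everything after the common prefix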
1232            const bool skip_choice =
1233                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1234                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1235
1236            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1237            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1238            size_t li = n_base1 - task.common_prefix;
1239            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1240                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
1241            }
1242            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1243            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1244            // FIXME: this uses the wrong first logits when not skipping the choice word
1245            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
1246            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1247                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
1248            }
1249        }
1250        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
1251
1252        size_t ir = 0;
1253        for (size_t i = i0; i < i1; ++i) {
1254            auto & task = data[i];
1255
1256            const bool skip_choice =
1257                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1258                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1259
1260            float score_1st = 0;
1261            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1262            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1263            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
1264                score_1st += eval_results[ir++];
1265            }
1266            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
1267
1268            float score_2nd = 0;
1269            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1270            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1271            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1272                score_2nd += eval_results[ir++];
1273            }
1274            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
1275
1276            int result = score_1st > score_2nd ? 1 : 2;
1277
1278            if (result == task.answer) {
1279                ++n_correct;
1280            }
1281            ++n_done;
1282
1283            // print the accumulated accuracy mean x 100
1284            LOG("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
1285        }
1286
1287        i0 = i1 - 1;
1288    }
1289
1290    LOG("\n");
1291
1292    if (n_done < 100) return;
1293
1294    const float p = 1.f*n_correct/n_done;
1295    const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
1296
1297    LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
1298}
1299
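// multiple-choice tasks arrive as a binary blob (see multiple_choice_score below): strings are
// serialized as a uint32 length followed by the raw bytes, and each answer set as a uint32 count
// followed by the answer strings and their int labels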
static bool deserialize_string(std::istream & in, std::string & str) {
    uint32_t size;
    if (!in.read((char *)&size, sizeof(size)).fail()) {
        str.resize(size);
        if (!in.read((char *)&str[0], size).fail()) return true;
    }
    return false;
}

struct multiple_choice_answers {
    std::vector<std::string> answers;
    std::vector<int>         labels;
    bool deserialize(std::istream& in) {
        uint32_t n;
        in.read((char *)&n, sizeof(n));
        if (in.fail() || n > 100) return false; // 100 as the maximum number of answers should be enough for any practical purpose
        answers.resize(n);
        labels.resize(n);
        for (auto& a : answers) {
            if (!deserialize_string(in, a)) return false;
        }
        in.read((char *)labels.data(), n*sizeof(int));
        return !in.fail();
    }
};

struct multiple_choice_task {
    std::string question;         // the question (or context that needs to be continued)
    multiple_choice_answers mc1;  // possible answers (continuations) with a single correct answer
    multiple_choice_answers mc2;  // possible answers (continuations) with multiple correct answers - not handled yet
    bool deserialize(std::istream& in) {
        if (!deserialize_string(in, question)) return false;
        return mc1.deserialize(in) && mc2.deserialize(in);
    }

    // For evaluation
    size_t i_logits;        // starting index of logits in the llama_batch
    size_t common_prefix;   // max number of initial tokens that are the same in all sentences
    size_t required_tokens; // number of tokens needed to evaluate all answers
    std::vector<std::vector<llama_token>> seq_tokens;
    std::vector<float> log_probs;
};

static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
    if (task.question.empty() || task.mc1.answers.empty()) {
        if (log_error) {
            LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__);
        }
        return false;
    }
    task.seq_tokens.reserve(task.mc1.answers.size());
    for (auto& answer : task.mc1.answers) {
        if (answer.empty()) {
            if (log_error) {
                LOG_ERR("%s: found empty answer\n", __func__);
            }
            return false;
        }
        task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true));
    }
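    // determine the longest common token prefix shared by all answer sequences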
    auto min_len = task.seq_tokens.front().size();
    for (auto& seq : task.seq_tokens) {
        min_len = std::min(min_len, seq.size());
    }
    task.common_prefix = 0;
    for (size_t k = 0; k < min_len; ++k) {
        auto token = task.seq_tokens[0][k];
        bool all_same = true;
        for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
            if (task.seq_tokens[i][k] != token) {
                all_same = false;
                break;
            }
        }
        if (!all_same) {
            break;
        }
        ++task.common_prefix;
    }
    task.required_tokens = task.common_prefix;
    for (auto& seq : task.seq_tokens) {
        task.required_tokens += seq.size() - task.common_prefix;
    }
    return true;
}

//
// Calculates the score for multiple choice tasks with a single correct answer, read from the prompt.
// Commonly used LLM evaluation metrics of this type are
//   * ARC
//   * HellaSwag
//   * MMLU
//   * TruthfulQA
//
// Validation datasets for these 4 tests can be found at
//     https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
// The data for these datasets was extracted from
//     git@hf.co:datasets/allenai/ai2_arc
//     https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
//     git@hf.co:datasets/Stevross/mmlu
//     https://huggingface.co/datasets/truthful_qa
//
static void multiple_choice_score(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::istringstream strstream(params.prompt);
    uint32_t n_task;
    strstream.read((char *)&n_task, sizeof(n_task));
    if (strstream.fail() || n_task == 0) {
        LOG_ERR("%s: no tasks\n", __func__);
        return;
    }
    LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task);
    std::vector<uint32_t> task_pos(n_task);
    strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
    if (strstream.fail()) {
        LOG_ERR("%s: failed to read task positions from prompt\n", __func__);
        return;
    }

    std::vector<multiple_choice_task> tasks;
    if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
        // Use all tasks
        tasks.resize(n_task);
        LOG_INF("%s: reading tasks", __func__);
        int n_dot = std::max((int) n_task/100, 1);
        int i = 0;
        for (auto& task : tasks) {
            ++i;
            if (!task.deserialize(strstream)) {
                LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task);
                return;
            }
            if (i%n_dot == 0) LOG(".");
        }
        LOG("done\n");
    }
    else {
        LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
        std::mt19937 rng(1);
        std::vector<int> aux(n_task);
        for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
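        // sample tasks without replacement: scale maps rng() into [0, 1),
        // and each chosen slot in aux is refilled with the last element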
        float scale = 1.f/(1.f + (float)std::mt19937::max());
        tasks.resize(params.multiple_choice_tasks);
        for (auto& task : tasks) {
            int j = (int)(scale * rng() * aux.size());
            int idx = aux[j];
            aux[j] = aux.back();
            aux.pop_back();
            strstream.seekg(task_pos[idx], std::ios::beg);
            if (!task.deserialize(strstream)) {
                LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
                return;
            }
        }
        n_task = params.multiple_choice_tasks;
    }

    LOG_INF("%s: preparing task data", __func__);
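    // for large task counts, tokenize and prepare the tasks in parallel;
    // each worker claims chunks of K_TOKEN_CHUNK tasks via an atomic counter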
    if (n_task > 500) {
        LOG("...");
        std::atomic<int> counter(0);
        std::atomic<int> n_bad(0);
        auto prepare = [&counter, &n_bad, &tasks, ctx] () {
            int num_tasks = tasks.size();
            int n_bad_local = 0;
            while (true) {
                int first = counter.fetch_add(K_TOKEN_CHUNK);
                if (first >= num_tasks) {
                    if (n_bad_local > 0) n_bad += n_bad_local;
                    break;
                }
                int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
                for (int i = first; i < last; ++i) {
                    if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
                }
            }
        };
        size_t max_thread = std::thread::hardware_concurrency();
        max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
        std::vector<std::thread> workers(max_thread-1);
        for (auto& w : workers) w = std::thread(prepare);
        prepare();
        for (auto& w : workers) w.join();
        LOG("done\n");
        int nbad = n_bad;
        if (nbad > 0) {
            LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad);
            return;
        }
    } else {
        int n_dot = std::max((int) n_task/100, 1);
        int i_task = 0;
        for (auto& task : tasks) {
            ++i_task;
            if (!multiple_choice_prepare_one_task(ctx, task, true)) {
                return;
            }
            if (i_task%n_dot == 0) {
                LOG(".");
            }
        }
        LOG("done\n");
    }

    LOG_INF("%s : calculating multiple choice score over %zu tasks.\n", __func__, tasks.size());

    LOG("\ntask\tacc_norm\n");

    const int n_ctx   = llama_n_ctx(ctx);
    const int n_batch = params.n_batch;

    const int n_vocab = llama_vocab_n_tokens(vocab);

    const int max_tasks_per_batch = 32;
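    // each task needs one sequence id per answer, so cap the number of
    // simultaneous sequences by what the context supports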
    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

    std::vector<float> tok_logits(n_vocab);
    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

    std::vector<std::pair<size_t, llama_token>> eval_pairs;
    std::vector<float> eval_results;
    std::vector<std::thread> workers(std::thread::hardware_concurrency());
    std::vector<int> batch_indices;

    int n_done = 0;
    int n_correct = 0;
    int n_tot_answers = 0;

    for (size_t i0 = 0; i0 < tasks.size(); i0++) {
        int n_cur = 0;

        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

        common_batch_clear(batch);

        // batch as many tasks as possible into the available context
        // each task gets one unique sequence id per possible answer
        // the common prefix is shared among these sequences to save tokens
        // we extract logits only from the last common token and from all answer tokens except the last one of each sequence
        int s0 = 0;
        while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
            auto& cur_task = tasks[i1];
            int n_logits = 0;

            int num_answers = cur_task.seq_tokens.size();
            if (s0 + num_answers > max_seq) {
                if (s0 == 0) {
                    LOG_ERR("%s : task %zu requires a higher -np|--parallel value (at least %d)\n", __func__, i0, num_answers);
                    return;
                }
                break;
            }

            if (int(batch_indices.size()) != num_answers) {
                batch_indices.resize(num_answers);
            }

            for (int s = 0; s < num_answers; ++s) {
                batch_indices[s] = s0 + s;
            }

            for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
                common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indices, false);
            }
            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
            n_logits += 1;

            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
                    common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                    n_logits += needs_logits;
                }
            }

            s0 += num_answers;

            cur_task.i_logits = i_logits;
            i_logits += n_logits;

            n_cur += cur_task.required_tokens;
            if (++i1 == tasks.size()) {
                break;
            }
        }

        if (i0 == i1) {
            LOG_ERR("%s : task %zu does not fit in the context window (requires %zu tokens)\n", __func__, i0, tasks[i0].required_tokens);
            return;
        }

        llama_memory_clear(llama_get_memory(ctx), true);

        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return;
        }

        // Compute log-probs in parallel
        // First we collect all tasks
        eval_pairs.clear();
        for (size_t i = i0; i < i1; ++i) {
            auto& cur_task = tasks[i];
            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
                    eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
                }
            }
        }
        // Then we do the actual calculation
        compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

        size_t ir = 0;

        // compute the logprobs for each ending of the decoded tasks
        for (size_t i = i0; i < i1; ++i) {
            auto & cur_task = tasks[i];
            //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
            //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
            //    if (cur_task.mc1.labels[j] == 1) {
            //        LOG("%d", j+1);
            //    }
            //}
            //LOG("\n    common_prefix: %zu\n", cur_task.common_prefix);

            // get the logits of the last token of the common prefix
            std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));

            const auto first_probs = softmax(tok_logits);

            cur_task.log_probs.resize(cur_task.seq_tokens.size());
            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                size_t count = 1;
                float  log_prob  = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
                for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
                    //LOG("        %zu  %g\n", ir, eval_results[ir]);
                    ++count;
                    log_prob += eval_results[ir++];
                }
                cur_task.log_probs[s] = log_prob / count;
                //LOG("        Final: %g\n", log_prob / count);
                //LOG("    <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
            }

            // Find the ending with maximum logprob
            size_t logprob_max_idx = 0;
            float  logprob_max_val = cur_task.log_probs[0];
            for (size_t s = 1; s < cur_task.log_probs.size(); s++) {
                if (cur_task.log_probs[s] > logprob_max_val) {
                    logprob_max_val = cur_task.log_probs[s];
                    logprob_max_idx = s;
                }
            }

            n_tot_answers += cur_task.log_probs.size();
            if (cur_task.mc1.labels[logprob_max_idx] == 1) {
                ++n_correct;
            }
            ++n_done;

            // Print the accumulated accuracy mean x 100
            LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
        }

        i0 = i1 - 1;
    }

    llama_batch_free(batch);

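    // when only a random subset of the tasks was evaluated, skip the final
    // summary if too few tasks were completed for the statistics to be meaningful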
    if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;

    float p = 1.f*n_correct/n_done;
    float sigma = sqrt(p*(1-p)/(n_done-1));
    LOG("\n");
    LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
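    // accuracy expected from random guessing: one correct pick per task out of n_tot_answers candidates overall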
    p = 1.f*n_done/n_tot_answers;
    sigma = sqrt(p*(1-p)/(n_done-1));
    LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);

    LOG_INF("\n");
}

static void kl_divergence(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    if (params.logits_file.empty()) {
        LOG_ERR("%s: you must provide the name of a file containing the log probabilities of the base model\n", __func__);
        return;
    }
    std::ifstream in(params.logits_file.c_str(), std::ios::binary);
    if (!in) {
        LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str());
        return;
    }
    {
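        // the logits file must start with the 8-byte magic "_logits_"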
        char check[9]; check[8] = 0;
        in.read(check, 8);
        if (in.fail() || strncmp("_logits_", check, 8) != 0) {
            LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
            return;
        }
    }

    uint32_t n_ctx;
    in.read((char *)&n_ctx, sizeof(n_ctx));
    if (n_ctx > llama_n_ctx(ctx)) {
        LOG_ERR("%s: %s has been computed with a context of %u, while the current context is %d. Increase it with -c and retry\n",
                __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
        return;
    }

    int n_vocab;
    int n_chunk;
    in.read((char *)&n_vocab, sizeof(n_vocab));
    in.read((char *)&n_chunk, sizeof(n_chunk));
    if (in.fail()) {
        LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
        return;
    }
    if (n_vocab != llama_vocab_n_tokens(vocab)) {
        LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
    }

    std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
    if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
        LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
        return;
    }

    const int n_batch = params.n_batch;
    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
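    // per-token record in the logits file: 4 uint16 values hold two floats
    // (scale and minimum log-prob), followed by the quantized log-probs,
    // with n_vocab padded up to an even count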
    const int nv = 2*((n_vocab + 1)/2) + 4;
    const bool add_bos = llama_vocab_get_add_bos(vocab);
    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

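    // only the second half of each chunk is scored (minus its final token),
    // i.e. n_ctx - 1 - n_ctx/2 positions per chunk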
    std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
    std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
    std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve(size_t(n_ctx) * n_vocab);
    }

    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

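    // mean and standard error of the mean from an accumulated sum and sum of squares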
    auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
        if (count < 1) {
            return std::make_pair(0., 0.);
        }
        double f = sum/count;
        double df = sum2/count - f*f;
        df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
        return std::make_pair(f, df);
    };
    auto covariance = [] (double suma, double sumb, double sumab, size_t count) {
        if (count < 10) {
            return 0.0;
        }
        double var = sumab/count - (suma/count)*(sumb/count);
        var /= count - 1;
        return var;
    };

    kl_divergence_result kld;
    auto    kld_ptr =    kld_values.data();
    auto p_diff_ptr = p_diff_values.data();

    for (int i = 0; i < n_chunk; ++i) {
        const int start =     i * n_ctx;
        const int end   = start + n_ctx;

        const auto t_start = std::chrono::high_resolution_clock::now();

        if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
            return;
        }

        // clear the KV cache
        llama_memory_clear(llama_get_memory(ctx), true);

        llama_batch batch = llama_batch_init(n_batch, 0, 1);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }

            common_batch_clear(batch);
            for (int k = 0; k < batch_size; k++) {
                common_batch_add(batch, tokens[batch_start + k], j*n_batch + k, {0}, true);
            }

            if (llama_decode(ctx, batch)) {
                LOG_ERR("%s : failed to eval\n", __func__);
                llama_batch_free(batch);
                return;
            }

            // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;

            if (num_batches > 1) {
                const auto * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
            }
        }

        llama_batch_free(batch);

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                LOG("%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            LOG("%.2f minutes\n", total_seconds / 60.0);
        }
        LOG("\n");
        LOG("chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");

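        // score only the second half of the window, matching the range covered by the stored log-probs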
        const int first = n_ctx/2;
        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
        p_diff_ptr += n_ctx - 1 - first;
        kld_ptr    += n_ctx - 1 - first;

        LOG("%4d", i+1);

        auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
        const double ppl_val = exp(log_ppl.first);
        const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
        LOG("    %9.4lf ± %9.4lf", ppl_val, ppl_unc);

        auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
        const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
        const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
        const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
        LOG("    %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);

        auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
        LOG("    %10.5lf ± %10.5lf", kl_div.first, kl_div.second);

        auto p_diff_mse   = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
        const double p_diff_rms_val = sqrt(p_diff_mse.first);
        const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
        LOG("    %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

        double p_top_val = 1.*kld.n_same_top/kld.count;
        double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
        LOG("    %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);

        LOG("\n");

        logits.clear();
    }
    LOG("\n");

    if (kld.count < 100) return; // we do not wish to do statistics on so few values

    std::sort(kld_values.begin(), kld_values.end());
    std::sort(p_diff_values.begin(), p_diff_values.end());

    LOG("====== Perplexity statistics ======\n");

    auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
    const double ppl_val = exp(log_ppl.first);
    const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
    LOG("Mean PPL(Q)                   : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);

    auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
    const double ppl_base_val = exp(log_ppl_base.first);
    const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
    LOG("Mean PPL(base)                : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);

    const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
    // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
    const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
    LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);

    const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
    const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
    LOG("Mean ln(PPL(Q)/PPL(base))     : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);

    const double ppl_ratio_val = exp(log_ppl_ratio_val);
    const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
    LOG("Mean PPL(Q)/PPL(base)         : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);

    const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
    const double ppl_diff_val = ppl_val - ppl_base_val;
    const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
    LOG("Mean PPL(Q)-PPL(base)         : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);

    LOG("\n");

    LOG("====== KL divergence statistics ======\n");
    auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
    LOG("Mean    KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
    auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
                                               : kld_values[kld_values.size()/2];

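    // percentile with linear interpolation over the pre-sorted values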
    auto percentile = [] (const std::vector<float> & values, float fraction) {
        if (fraction <= 0) return values.front();
        if (fraction >= 1) return values.back();
        float p = fraction*(values.size() - 1);
        size_t ip = size_t(p); p -= ip;
        return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
    };

    LOG("Maximum KLD: %10.6f\n", kld_values.back());
    LOG("99.9%%   KLD: %10.6f\n", percentile(kld_values, 0.999f));
    LOG("99.0%%   KLD: %10.6f\n", percentile(kld_values, 0.990f));
    LOG("95.0%%   KLD: %10.6f\n", percentile(kld_values, 0.950f));
    LOG("90.0%%   KLD: %10.6f\n", percentile(kld_values, 0.900f));
    LOG("Median  KLD: %10.6f\n", kld_median);
    LOG("10.0%%   KLD: %10.6f\n", percentile(kld_values, 0.100f));
    LOG(" 5.0%%   KLD: %10.6f\n", percentile(kld_values, 0.050f));
    LOG(" 1.0%%   KLD: %10.6f\n", percentile(kld_values, 0.010f));
    LOG(" 0.1%%   KLD: %10.6f\n", percentile(kld_values, 0.001f));
    LOG("Minimum KLD: %10.6f\n", kld_values.front());

    LOG("\n");

    LOG("====== Token probability statistics ======\n");

    auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
    LOG("Mean    Δp: %6.3lf ± %5.3lf %%\n",  100.0*p_diff.first, 100.0*p_diff.second);

    auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
                                                     : p_diff_values[p_diff_values.size()/2];

    LOG("Maximum Δp: %6.3lf%%\n",  100.0*p_diff_values.back());
    LOG("99.9%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
    LOG("99.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
    LOG("95.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
    LOG("90.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
    LOG("75.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
    LOG("Median  Δp: %6.3lf%%\n",  100.0*p_diff_median);
    LOG("25.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
    LOG("10.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
    LOG(" 5.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
    LOG(" 1.0%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
    LOG(" 0.1%%   Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
    LOG("Minimum Δp: %6.3lf%%\n",  100.0*p_diff_values.front());

    auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
    // LOG("MSE Δp    : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);

    const double p_diff_rms_val = sqrt(p_diff_mse.first);
    const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
    LOG("RMS Δp    : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

    const double same_top_p = 1.0*kld.n_same_top/kld.count;
    LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
}

int main(int argc, char ** argv) {
    common_params params;

    params.n_ctx = 512;
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        return 1;
    }

    common_init();

    const int32_t n_ctx = params.n_ctx;

    if (n_ctx <= 0) {
        LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
        return 1;
    }

    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;

    if (ppl) {
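        // pack several n_ctx-sized chunks into one batch: one sequence per chunk,
        // so the KV cache (params.n_ctx) must hold n_seq full chunks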
        const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
        const int32_t n_kv = n_seq * n_ctx;

        params.n_parallel = n_seq;
        params.n_ctx      = n_kv;

        params.n_batch = std::min(params.n_batch, n_kv);
    } else {
        params.n_batch = std::min(params.n_batch, params.n_ctx);
        if (params.kl_divergence) {
            params.n_parallel = 1;
        } else {
            // ensure there's at least enough seq_ids for HellaSwag
            params.n_parallel = std::max(4, params.n_parallel);
        }
    }

    if (params.ppl_stride > 0) {
        LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
                params.n_ctx, params.n_ctx + params.ppl_stride/2);
        params.n_ctx += params.ppl_stride/2;
    }

    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    auto llama_init = common_init_from_params(params);

    auto * model = llama_init->model();
    auto * ctx   = llama_init->context();

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    const int n_ctx_train = llama_model_n_ctx_train(model);

    if (params.n_ctx > n_ctx_train) {
        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, params.n_ctx);
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    struct results_perplexity results;
    if (params.hellaswag) {
        hellaswag_score(ctx, params);
    } else if (params.winogrande) {
        winogrande_score(ctx, params);
    } else if (params.multiple_choice) {
        multiple_choice_score(ctx, params);
    } else if (params.kl_divergence) {
        kl_divergence(ctx, params);
    } else {
        results = perplexity(ctx, params, n_ctx);
    }

    LOG("\n");
    llama_perf_context_print(ctx);
    llama_memory_breakdown_print(ctx);

    llama_backend_free();

    return 0;
}