1// Test for state restore with fragmented KV cache
  2// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
  3// The issue was that state restore required contiguous KV cache slots,
  4// which fails when the cache is fragmented.
  5//
  6// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
  7// in state_read_meta(), allowing non-contiguous slot allocation.
  8
  9#include "arg.h"
 10#include "common.h"
 11#include "llama.h"
 12
 13#include <vector>
 14#include <cstdio>
 15#include <cstring>
 16
 17int main(int argc, char ** argv) {
 18    common_params params;
 19
 20    params.sampling.seed = 1234;
 21    params.kv_unified = true;
 22    params.n_parallel = 3;
 23    params.n_ctx = 256;
 24
 25    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
 26        return 1;
 27    }
 28
 29    common_init();
 30
 31    // init
 32    common_init_result_ptr llama_init = common_init_from_params(params);
 33
 34    llama_model * model = llama_init->model();
 35    llama_context * ctx = llama_init->context();
 36
 37    if (model == nullptr || ctx == nullptr) {
 38        fprintf(stderr, "%s : failed to init\n", __func__);
 39        return 1;
 40    }
 41
 42    GGML_UNUSED(model);
 43
 44    // tokenize prompt
 45    std::vector<llama_token> tokens(70, 1);
 46
 47    // interleave the 3 sequences:
 48    // 01201230123...
 49    llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
 50    for (size_t i = 0; i < tokens.size(); i++) {
 51        for (int s = 0; s < params.n_parallel; ++s) {
 52            common_batch_add(batch, tokens[i], i, {s}, false);
 53        }
 54    }
 55    batch.logits[batch.n_tokens - 1] = true;
 56
 57    if (llama_decode(ctx, batch)) {
 58        fprintf(stderr, "%s : failed to decode seq 0\n", __func__);
 59        return 1;
 60    }
 61
 62    fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());
 63
 64    // Save state of seq 1
 65    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
 66    const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
 67    if (ncopy != seq_state.size()) {
 68        fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
 69        return 1;
 70    }
 71    fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);
 72
 73    // clear seq 1 to create a "hole" in the KV cache (fragmentation)
 74    // 0.20.20.20.2....
 75    llama_memory_t mem = llama_get_memory(ctx);
 76    llama_memory_seq_rm(mem, 1, -1, -1);
 77    fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);
 78
 79    // Now the cache has holes where seq 1 was
 80    // This creates fragmentation - there's no contiguous block large enough
 81    // for the seq 1 state if we only look for contiguous slots
 82
 83    // Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
 84    // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
 85    // Before the fix, this would fail with "failed to find available cells in kv cache"
 86    const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
 87    if (nset != seq_state.size()) {
 88        fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
 89                __func__, nset, seq_state.size());
 90        fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
 91        llama_batch_free(batch);
 92        return 1;
 93    }
 94    fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);
 95
 96    // Verify we can decode with the restored state
 97    // Generate one token to verify the restored state is usable
 98    auto sparams = llama_sampler_chain_default_params();
 99    llama_sampler * smpl = llama_sampler_chain_init(sparams);
100    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
101
102    auto next_token = llama_sampler_sample(smpl, ctx, -1);
103    auto next_token_str = common_token_to_piece(ctx, next_token);
104
105    common_batch_clear(batch);
106    common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);
107
108    if (llama_decode(ctx, batch)) {
109        fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
110        llama_sampler_free(smpl);
111        llama_batch_free(batch);
112        return 1;
113    }
114
115    fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
116    fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);
117
118    llama_sampler_free(smpl);
119    llama_batch_free(batch);
120
121    return 0;
122}