#include "llama-kv-cache.h"

#include "llama-impl.h"
#include "llama-io.h"
#include "llama-model.h"
#include "llama-context.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
#include <unordered_map>

//
// llama_kv_cache
//

llama_kv_cache::llama_kv_cache(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   unified,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_pad,
                 uint32_t   n_swa,
           llama_swa_type   swa_type,
    const layer_filter_cb & filter,
    const  layer_reuse_cb & reuse) :
    model(model), hparams(model.hparams), v_trans(v_trans),
    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

    GGML_ASSERT(kv_size % n_pad == 0);

    const uint32_t n_layer_kv = hparams.n_layer_kv();

    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
    struct ggml_backend_buft_comparator {
        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
        }
    };
    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

    // create a context for each buffer type
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map.emplace(buft, ctx);

            return ctx;
        }

        return it->second.get();
    };

    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);

    v_heads.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_heads[s] = 0;
    }

    v_cells.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_cells[s].resize(kv_size);
    }

    // by default, all sequence ids are mapped to the 0th stream
    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);

    if (n_stream > 1) {
        seq_to_stream.resize(n_stream, 0);
        for (uint32_t s = 0; s < n_stream; ++s) {
            seq_to_stream[s] = s;
        }
    }
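
    // illustrative sketch of the resulting mapping (assuming n_seq_max = 4):
    //   unified (n_stream = 1): seq_to_stream = { 0, 0, 0, 0, ... } (LLAMA_MAX_SEQ entries)
    //   split   (n_stream = 4): seq_to_stream = { 0, 1, 2, 3 }     (one stream per sequence)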

    // [TAG_V_CACHE_VARIABLE]
    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
                __func__, hparams.n_embd_v_gqa_max());
    }

    const bool is_mla = hparams.is_mla();

    for (uint32_t il = 0; il < hparams.n_layer; il++) {
        if (!hparams.has_kv(il)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
            continue;
        }

        if (filter && !filter(il)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
            continue;
        }

        // [TAG_V_CACHE_VARIABLE]
        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(il);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for kv cache");
        }

        const bool has_k = true;
        const bool has_v = !is_mla;

        ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
        ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
        if (has_k) { ggml_format_name(k, "cache_k_l%d", il); }
        if (has_v) { ggml_format_name(v, "cache_v_l%d", il); }

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;

        for (uint32_t s = 0; s < n_stream; ++s) {
            k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
            v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
        }

        map_layer_ids[il] = layers.size();

        layers.push_back({ il, k, v, k_stream, v_stream, });
    }
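
    // for reference, a sketch of the per-layer tensors created above:
    //   k: [n_embd_k_gqa, kv_size, n_stream] (always present)
    //   v: [n_embd_v_gqa, kv_size, n_stream] (absent for MLA models)
    // each stream s gets a 2D view [n_embd_*_gqa, kv_size] at byte offset s*nb[2]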

    if (reuse) {
        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);

        for (uint32_t il = 0; il < hparams.n_layer; il++) {
            const int32_t il_reuse = reuse(il);

            if (il_reuse < 0) {
                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                continue;
            }

            if (filter && !filter(il)) {
                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
                continue;
            }

            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());

            map_layer_ids[il] = map_layer_ids[il_reuse];

            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
        }
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto & [buft, ctx] : ctx_map) {
        ggml_backend_buffer_t buf;
        if (model.hparams.no_alloc) {
            buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
            for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
                t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
            }
        } else {
            buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
        }
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for kv cache");
        }

        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

        ggml_backend_buffer_clear(buf, 0);
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }

    {
        const size_t memory_size_k = size_k_bytes();
        const size_t memory_size_v = size_v_bytes();

        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
    }

    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}

void llama_kv_cache::clear(bool data) {
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_cells[s].reset();
        v_heads[s] = 0;
    }

    if (data) {
        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
}

bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    if (seq_id >= 0) {
        auto & cells = v_cells[seq_to_stream[seq_id]];
        auto & head  = v_heads[seq_to_stream[seq_id]];

        uint32_t new_head = cells.size();

        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }

        // If we freed up a slot, set head to it so searching can start there.
        if (new_head != cells.size() && new_head < head) {
            head = new_head;
        }
    } else {
        // match any sequence
        for (uint32_t s = 0; s < n_stream; ++s) {
            auto & cells = v_cells[s];
            auto & head  = v_heads[s];

            uint32_t new_head = cells.size();

            for (uint32_t i = 0; i < cells.size(); ++i) {
                if (!cells.pos_in(i, p0, p1)) {
                    continue;
                }

                cells.rm(i);

                if (new_head == cells.size()) {
                    new_head = i;
                }
            }

            // If we freed up a slot, set head to it so searching can start there.
            if (new_head != cells.size() && new_head < head) {
                head = new_head;
            }
        }
    }

    return true;
}

void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());

    const auto s0 = seq_to_stream[seq_id_src];
    const auto s1 = seq_to_stream[seq_id_dst];

    if (s0 == s1) {
        // since both sequences are in the same stream, no data copy is necessary
        // we just have to update the cells meta data

        auto & cells = v_cells[s0];

        if (seq_id_src == seq_id_dst) {
            return;
        }

        if (p0 < 0) {
            p0 = 0;
        }

        if (p1 < 0) {
            p1 = std::numeric_limits<llama_pos>::max();
        }

        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            if (cells.seq_has(i, seq_id_src)) {
                cells.seq_add(i, seq_id_dst);
            }
        }

        return;
    }

    // cross-stream sequence copies require copying the actual buffer data

    bool is_full = true;

    if (p0 > 0 && p0 + 1 < (int) get_size()) {
        is_full = false;
    }

    if (p1 > 0 && p1 + 1 < (int) get_size()) {
        is_full = false;
    }

    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");

    // enqueue the copy operation - the buffer copy will be performed during the next update
    sc_info.ssrc.push_back(s0);
    sc_info.sdst.push_back(s1);

    v_cells[s1].reset();
    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
        if (v_cells[s0].seq_has(i, seq_id_src)) {
            llama_pos pos   = v_cells[s0].pos_get(i);
            llama_pos shift = v_cells[s0].get_shift(i);

            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);

            if (shift != 0) {
                pos -= shift;
                assert(pos >= 0);
            }

            v_cells[s1].pos_set(i, pos);
            v_cells[s1].seq_add(i, seq_id_dst);

            if (shift != 0) {
                v_cells[s1].pos_add(i, shift);
            }

            v_cells[s1].ext_set(i, ext);
        }
    }

    v_heads[s1] = v_heads[s0];

    //for (uint32_t s = 0; s < n_stream; ++s) {
    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
    //}
}
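
// sketch of a cross-stream copy (hypothetical seq ids, assuming seq 0 lives on
// stream 0 and seq 1 on stream 1):
//   seq_cp(0, 1, -1, -1) -> enqueues (ssrc = 0, sdst = 1) in sc_info and mirrors
//   the cell metadata; the actual per-layer tensor copy happens later, inside
//   update(), via ggml_backend_tensor_copy()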

void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head  = v_heads[seq_to_stream[seq_id]];

    uint32_t new_head = cells.size();

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (cells.seq_keep(i, seq_id)) {
            if (new_head == cells.size()) {
                new_head = i;
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cells.size() && new_head < head) {
        head = new_head;
    }
}

void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head  = v_heads[seq_to_stream[seq_id]];

    if (shift == 0) {
        return;
    }

    uint32_t new_head = cells.size();

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over all cells.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }

        if (cells.seq_has(i, seq_id)) {
            if (cells.pos_add(i, shift)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    // Otherwise we just start the next search from the beginning.
    head = new_head != cells.size() ? new_head : 0;
}

void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];

    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }

        if (cells.seq_has(i, seq_id)) {
            cells.pos_div(i, d);
        }
    }
}

llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    const auto & cells = v_cells[seq_to_stream[seq_id]];

    return cells.seq_pos_min(seq_id);
}

llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    const auto & cells = v_cells[seq_to_stream[seq_id]];

    return cells.seq_pos_max(seq_id);
}

std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
    for (const auto & [ctx, buf] : ctxs_bufs) {
        ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());

        if (hparams.no_alloc) {
            GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
            ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
        } else {
            // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
            ret[buft] += ggml_backend_buffer_get_size(buf.get());
        }
    }

    return ret;
}

llama_memory_context_ptr llama_kv_cache::init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) {
    GGML_UNUSED(embd_all);

    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos = prepare(ubatches);
        if (sinfos.empty()) {
            break;
        }

        return std::make_unique<llama_kv_cache_context>(
                this, std::move(sinfos), std::move(ubatches));
    } while (false);

    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_kv_cache::init_full() {
    return std::make_unique<llama_kv_cache_context>(this);
}

llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
    GGML_UNUSED(optimize);

    bool do_shift = get_has_shift();

    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
}

llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
    llama_kv_cache::slot_info_vec_t res;

    struct state_t {
        slot_info sinfo; // slot info for the ubatch

        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch

        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
    };

    // remember the old state of the cells so we can restore it in the end
    std::vector<state_t> states;

    bool success = true;

    for (const auto & ubatch : ubatches) {
        // only find a suitable slot for the ubatch. don't modify the cells yet
        const auto sinfo_new = find_slot(ubatch, false);
        if (sinfo_new.empty()) {
            success = false;
            break;
        }

        // remember the position that we found
        res.push_back(sinfo_new);

        // store the old state of the cells in the recovery stack
        {
            state_t state = { sinfo_new, v_heads, {} };

            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
                auto & cells = v_cells[sinfo_new.strm[s]];

                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
            }

            states.push_back(std::move(state));
        }

        // now emplace the ubatch
        apply_ubatch(sinfo_new, ubatch);
    }

    GGML_ASSERT(!states.empty() || !success);

    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
        const auto & sinfo = it->sinfo;

        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            auto & cells = v_cells[sinfo.strm[s]];
            auto & head  = v_heads[sinfo.strm[s]];

            cells.set(sinfo.idxs[s], it->v_cells[s]);
            head = it->v_heads_old[s];
        }
    }

    if (!success) {
        return {};
    }

    return res;
}
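
// sketch of the prepare() flow above: for each ubatch we find_slot(), snapshot
// the affected cells/heads, and apply_ubatch() so that subsequent ubatches see
// the updated state; afterwards everything is rolled back unconditionally and
// the returned slot infos are re-applied ubatch by ubatch as the batch is processed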

bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
    bool updated = false;

    auto * sched = lctx->get_sched();

    if (!sc_info.empty()) {
        assert(n_stream > 1 && "stream copy should never happen with a single stream");

        llama_synchronize(lctx);

        const size_t n_copy = sc_info.ssrc.size();

        for (size_t i = 0; i < n_copy; ++i) {
            const auto ssrc = sc_info.ssrc[i];
            const auto sdst = sc_info.sdst[i];

            assert(ssrc < n_stream);
            assert(sdst < n_stream);

            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);

            assert(ssrc != sdst);

            for (uint32_t il = 0; il < layers.size(); ++il) {
                const auto & layer = layers[il];

                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);

                if (layer.v_stream[ssrc]) {
                    ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
                }
            }
        }
    }

    if (do_shift) {
        if (!get_can_shift()) {
            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
        }

        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

        // apply K-shift if needed
        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            ggml_backend_sched_reset(sched);

            auto * res = lctx->get_gf_res_reserve();

            res->reset();

            auto * gf = build_graph_shift(res, lctx);
            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                return updated;
            }

            res->set_inputs(nullptr);

            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                return updated;
            }

            updated = true;
        }

        for (uint32_t s = 0; s < n_stream; ++s) {
            auto & cells = v_cells[s];

            cells.reset_shift();
        }
    }

    return updated;
}

llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {

    if (debug > 0) {
        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
            const auto seq_id = ubatch.seq_id_unq[s];
            const auto stream_id = seq_to_stream[seq_id];
            const auto & cells = v_cells[stream_id];
            const uint32_t head_cur = v_heads[stream_id];

            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

            if ((debug == 2 && n_swa > 0) || debug > 2) {
                std::string ss;
                for (uint32_t i = 0; i < cells.size(); ++i) {
                    if (cells.is_empty(i)) {
                        ss += '.';
                    } else {
                        assert(cells.seq_count(i) >= 1);

                        if (cells.seq_count(i) == 1) {
                            ss += std::to_string(cells.seq_get(i));
                        } else {
                            ss += 'M';
                        }
                    }
                    if (i%256 == 255) {
                        ss += " *";
                        ss += '\n';
                    }
                }
                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
            }

            if ((debug == 2 && n_swa > 0) || debug > 2) {
                std::string ss;
                for (uint32_t i = 0; i < cells.size(); ++i) {
                    std::string cur;
                    if (cells.is_empty(i)) {
                        cur = '.';
                    } else {
                        cur = std::to_string(cells.pos_get(i));
                    }
                    const int n = cur.size();
                    for (int j = 0; j < 5 - n; ++j) {
                        cur += ' ';
                    }
                    ss += cur;
                    if (i%256 == 255) {
                        ss += " *";
                    }
                    if (i%64 == 63) {
                        ss += '\n';
                    }
                }
                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
            }

            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                if (cells.seq_pos_min(s) < 0) {
                    continue;
                }

                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
            }
        }
    }

    uint32_t n_tokens = ubatch.n_tokens;
    uint32_t n_seqs   = 1;

    if (n_stream > 1) {
        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

        n_seqs   = ubatch.n_seqs_unq;
        n_tokens = n_tokens / n_seqs;
    }

    slot_info res = {
        /*.s0   =*/ LLAMA_MAX_SEQ,
        /*.s1   =*/ 0,
        /*.strm =*/ { },
        /*.idxs =*/ { },
    };

    res.resize(n_seqs);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        const auto seq_id = ubatch.seq_id_unq[s];

        if (n_stream > 1) {
            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
        }

        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);

        res.strm[s] = seq_to_stream[seq_id];
        res.idxs[s].reserve(n_tokens);

        const auto & cells = v_cells[seq_to_stream[seq_id]];

        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];

        // if we have enough unused cells before the current head ->
        //   better to start searching from the beginning of the cache, hoping to fill it
        if (head_cur > cells.get_used() + 2*n_tokens) {
            head_cur = 0;
        }

        if (n_tokens > cells.size()) {
            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
            return { };
        }

        uint32_t n_tested = 0;

        // for contiguous slots, we test that all tokens in the ubatch fit, starting from the current head
        // for non-contiguous slots, we test the tokens one by one
        const uint32_t n_test = cont ? n_tokens : 1;

        while (true) {
            if (head_cur + n_test > cells.size()) {
                n_tested += cells.size() - head_cur;
                head_cur = 0;
                continue;
            }

            for (uint32_t i = 0; i < n_test; i++) {
                const auto idx = head_cur;

                head_cur++;
                n_tested++;

                //const llama_pos    pos    = ubatch.pos[i];
                //const llama_seq_id seq_id = ubatch.seq_id[i][0];

                // can we use this cell? either:
                //  - the cell is empty
                //  - the cell is occupied only by one sequence:
                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
                //    - mask SWA, using current max pos for that sequence in the cache
                //                always insert in the cell with minimum pos
                bool can_use = cells.is_empty(idx);

                if (!can_use && cells.seq_count(idx) == 1) {
                    const llama_pos pos_cell = cells.pos_get(idx);

                    // (disabled) causal mask
                    // note: it's better to purge any "future" tokens beforehand
                    //if (cells.seq_has(idx, seq_id)) {
                    //    can_use = pos_cell >= pos;
                    //}

                    if (!can_use) {
                        const llama_seq_id seq_id_cell = cells.seq_get(idx);

                        // SWA mask
                        if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                            can_use = true;
                        }
                    }
                }

                if (can_use) {
                    res.idxs[s].push_back(idx);
                } else {
                    if (cont) {
                        break;
                    }
                }
            }

            if (res.idxs[s].size() == n_tokens) {
                break;
            }

            if (cont) {
                res.idxs[s].clear();
            }

            if (n_tested >= cells.size()) {
                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
                return { };
            }
        }

        // we didn't find a suitable slot - return empty result
        if (res.idxs[s].size() < n_tokens) {
            return { };
        }
    }

    assert(res.s1 >= res.s0);

    return res;
}
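
// worked example of the SWA reuse check above (hypothetical numbers):
//   n_swa = 256, seq_pos_max(seq) = 1000, and a cell of that sequence holds
//   pos_cell = 500 -> is_masked_swa(n_swa, swa_type, 500, 1001) is true for a
//   standard sliding window, i.e. the position has already slid out of the
//   window of the next token, so the cell can be safely overwritten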

void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for a non-SWA cache, this will always be empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        seq_pos_max_rm[s] = -1;
    }

    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
            const uint32_t i = s*sinfo.size() + ii;

            auto & cells = v_cells[sinfo.strm[s]];

            const auto idx = sinfo.idxs[s][ii];

            if (!cells.is_empty(idx)) {
                assert(cells.seq_count(idx) == 1);

                const llama_seq_id seq_id = cells.seq_get(idx);
                const llama_pos    pos    = cells.pos_get(idx);

                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

                cells.rm(idx);
            }

            cells.pos_set(idx, ubatch.pos[i]);

            if (ubatch.is_pos_2d()) {
                llama_kv_cell_ext ext {
                    /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
                    /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
                };
                cells.ext_set(idx, ext);
            }

            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
                cells.seq_add(idx, ubatch.seq_id[i][s]);
            }
        }
    }

    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }

        GGML_ASSERT(s < seq_to_stream.size());

        auto & cells = v_cells[seq_to_stream[s]];

        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);

            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
        }
    }

    // move the head at the end of the slot
    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        auto & head = v_heads[sinfo.strm[s]];

        head = sinfo.idxs[s].back() + 1;
    }
}
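
// example of the purge above (hypothetical numbers): if seq 5 currently spans
// positions [10, 100] and this ubatch overwrites one of its cells holding
// pos 50, then seq_pos_max_rm[5] = 50 and seq_rm(5, 10, 51) drops positions
// [10, 50], leaving the contiguous range [51, 100] in the cache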

bool llama_kv_cache::get_can_shift() const {
    // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot.
    if (model.arch == LLM_ARCH_STEP35) {
        return false;
    }
    return true;
}

uint32_t llama_kv_cache::get_size() const {
    const auto & cells = v_cells[seq_to_stream[0]];

    return cells.size();
}

uint32_t llama_kv_cache::get_n_stream() const {
    return n_stream;
}

bool llama_kv_cache::get_has_shift() const {
    bool result = false;

    for (uint32_t s = 0; s < n_stream; ++s) {
        result |= v_cells[s].get_has_shift();
    }

    return result;
}

uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;

    // pad the n_kv value so that the graph remains constant across batches and can be reused
    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
    const uint32_t n_pad_cur = std::max(n_pad, 256u);

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const auto & cells = v_cells[sinfo.strm[s]];

        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
    }

    return result;
}
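
// worked example of the padding above (assuming n_pad = 32, so n_pad_cur = 256):
//   used_max_p1 = 100 -> GGML_PAD(100, 256) = 256 -> n_kv = min(size, 256)
//   used_max_p1 = 300 -> GGML_PAD(300, 256) = 512 -> n_kv = min(size, 512)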

ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    const uint64_t kv_size      = get_size();
    const uint64_t n_embd_k_gqa = k->ne[0];

    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));

    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

    return ggml_view_4d(ctx, k,
            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
            ggml_row_size(k->type, hparams.n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
}

ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const uint64_t kv_size      = get_size();
    const uint64_t n_embd_v_gqa = v->ne[0];

    // [TAG_V_CACHE_VARIABLE]
    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));

    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

    if (!v_trans) {
        // note: v->nb[1] <= v->nb[2]
        return ggml_view_4d(ctx, v,
                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                ggml_row_size(v->type, hparams.n_embd_head_v),          // v->nb[1]
                ggml_row_size(v->type, n_embd_v_gqa),                   // v->nb[2]
                ggml_row_size(v->type, n_embd_v_gqa*kv_size),           // v->nb[3]
                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
    }

    // note: v->nb[1] > v->nb[2]
    return ggml_view_4d(ctx, v,
            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
            ggml_row_size(v->type, kv_size),                        // v->nb[2]
            ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3]
            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
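
// note on the two V layouts above:
//   !v_trans (FA): one row per token   -> view [n_embd_head_v, n_head_kv, n_kv, ns]
//    v_trans     : one row per element -> view [n_kv, n_head_kv, n_embd_head_v, ns]
// both views select the stream range [sinfo.s0, sinfo.s1] via the byte offset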

ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
    GGML_UNUSED(sinfo);

    const int32_t ikv = map_layer_ids.at(il);

    ggml_tensor * k = layers[ikv].k;

    const int64_t n_embd_head = k_cur->ne[0];
    const int64_t n_head      = k_cur->ne[1];
    const int64_t n_tokens    = k_cur->ne[2];

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // we can merge dims 0 and 1
    // TODO: add ggml helper function for this?
    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);

    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);

    const int64_t n_stream = k->ne[2];

    if (n_stream > 1) {
        const int64_t kv_size = get_size();

        assert(n_embd_gqa == k->ne[0]);
        assert(kv_size    == k->ne[1]);

        // merge the buffer across all streams because the idxs are global
        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
    }

    // store the current K values into the cache
    return ggml_set_rows(ctx, k, k_cur, k_idxs);
}
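
// with multiple streams the K buffer is flattened to 2D above so that a global
// row index (stream*kv_size + cell) can address any cell - these indices are
// produced on the host side in set_input_k_idxs()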

ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
    GGML_UNUSED(sinfo);

    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const int64_t n_embd_head = v_cur->ne[0];
    const int64_t n_head      = v_cur->ne[1];
    const int64_t n_tokens    = v_cur->ne[2];

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // we can merge dims 0 and 1
    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);

    const int64_t n_stream = v->ne[2];

    // take this branch when FA is enabled (the V cache is not transposed)
    if (!v_trans) {
        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);

        if (n_stream > 1) {
            const int64_t kv_size = get_size();

            assert(n_embd_gqa == v->ne[0]);
            assert(kv_size    == v->ne[1]);

            // merge the buffer across all streams because the idxs are global
            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
        }

        return ggml_set_rows(ctx, v, v_cur, v_idxs);
    }

    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
        // we can merge dims 0, 1 and 2
        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
    } else {
        // otherwise -> make a copy to get contiguous data
        v_cur = ggml_cont_2d   (ctx, v_cur, n_embd_gqa, n_tokens);
    }

    // [TAG_V_CACHE_VARIABLE]
    if (n_embd_gqa < v->ne[0]) {
        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
    }

    // in this branch the v_idxs are constructed in such a way that each row is a single head element
    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));

    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
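
// in the transposed branch above, the destination is viewed as one element per
// row, so for cell idx and embedding element j the destination row index is
//   strm*kv_size*n_embd_v_gqa + j*kv_size + idx
// which matches the indices written by set_input_v_idxs() below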

ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    ggml_set_input(k_idxs);

    return k_idxs;
}

ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

    ggml_tensor * v_idxs;

    if (!v_trans) {
        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    } else {
        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
    }

    ggml_set_input(v_idxs);

    return v_idxs;
}

void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
    const uint32_t n_tokens = ubatch->n_tokens;
    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const int64_t offs = sinfo.strm[s]*get_size();

        for (uint32_t i = 0; i < sinfo.size(); ++i) {
            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
        }
    }
}

void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
    const uint32_t n_tokens = ubatch->n_tokens;
    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

    if (!v_trans) {
        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            const int64_t offs = sinfo.strm[s]*get_size();

            for (uint32_t i = 0; i < sinfo.size(); ++i) {
                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
            }
        }
    } else {
        // note: the V cache is transposed when not using flash attention
        const int64_t kv_size = get_size();

        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();

        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;

            for (uint32_t i = 0; i < sinfo.size(); ++i) {
                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
                }
            }
        }
    }
}

void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    int32_t * data = (int32_t *) dst->data;

    for (uint32_t s = 0; s < n_stream; ++s) {
        const auto & cells = v_cells[s];

        for (uint32_t i = 0; i < cells.size(); ++i) {
            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
        }
    }
}

struct args_set_input_kq_mask {
    const llama_hparams & hparams;
    const llama_ubatch  * ubatch;

    const std::vector<llama_kv_cells> & v_cells;
    const std::vector<uint32_t>       & seq_to_stream;

    uint32_t       n_swa;
    llama_swa_type swa_type;

    int64_t n_kv;
    int64_t n_stream;
    int64_t n_tps;
};

template<bool causal, bool swa, bool is_2d, bool alibi>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
  //const auto & hparams = args.hparams;
    const auto & ubatch  = args.ubatch;

    const auto & v_cells       = args.v_cells;
    const auto & seq_to_stream = args.seq_to_stream;

    const uint32_t       n_swa    = args.n_swa;
    const llama_swa_type swa_type = args.swa_type;

    const int64_t n_kv     = args.n_kv;
    const int64_t n_stream = args.n_stream;
    const int64_t n_tps    = args.n_tps;

    // the min position in the batch for each sequence
    llama_pos seq_pos_min[LLAMA_MAX_SEQ];
    std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);

    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
        const llama_seq_id seq_id = ubatch->seq_id[i][0];

        seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
    }

    for (uint32_t s = 0; s < n_stream; ++s) {
        // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
        std::unordered_map<llama_seq_id, uint32_t>              seq_srct;
        std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;

        for (uint32_t ii = 0; ii < n_tps; ++ii) {
            const uint32_t i = s*n_tps + ii;

            const llama_seq_id seq_id = ubatch->seq_id[i][0];

            const auto & cells = v_cells.at(seq_to_stream[seq_id]);

                  llama_pos p0 = -1;
            const llama_pos p1 = ubatch->pos[i];

            // for M-RoPE
            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;

            const uint64_t idst = n_kv*i;

            // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
            // the only cells that could differ are the ones with positions close to the
            //   positions in the batch (i.e. due to causal masking, SWA, etc.)
            // keep track of those cells and shortcut the loop to save time
            // note: this optimization is not compatible with ALiBi position encoding
            // ref:  https://github.com/ggml-org/llama.cpp/pull/18842
            bool prev = false;

            auto & idxs = seq_idxs[seq_id];

            if (!alibi) {
                if (seq_srct.find(seq_id) != seq_srct.end()) {
                    const uint32_t srct = seq_srct[seq_id];

                    const uint64_t idst_prev = n_kv*srct;

                    std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);

                    prev = true;
                } else {
                    idxs.clear();
                    idxs.reserve(ubatch->n_tokens + n_swa + 32);

                    seq_srct[seq_id] = i;
                }
            }

            for (uint32_t jj = 0; jj < n_kv; ++jj) {
                uint32_t j = jj;

                // we have an existing mask for this sequence -> only update the cells recorded in seq_idxs
                if (!alibi) {
                    if (prev) {
                        if (jj >= idxs.size()) {
                            break;
                        }

                        j = idxs[jj];
                    }
                }

                if (cells.is_empty(j)) {
                    goto skip;
                }

                // mask the token if not the same sequence
                if (!cells.seq_has(j, seq_id)) {
                    goto skip;
                }

                p0 = cells.pos_get(j);

                if (!alibi) {
                    if (!prev) {
                        // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
                        if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
                            idxs.push_back(j);
                        }
                    }
                }

                if (causal) {
                    // mask future tokens
                    if (p0 > p1) {
                        goto skip;
                    }

                    // M-RoPE causal mask
                    if (is_2d) {
                        if (p0 == p1) {
                            const auto & p0_ext = cells.ext_get(j);

                            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                                goto skip;
                            }
                        }
                    }
                }

                // apply SWA if any
                if (swa) {
                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
                        goto skip;
                    }
                }

                if (alibi) {
                    data[idst + j] = -std::abs(p0 - p1);
                } else {
                    data[idst + j] = 0.0f;
                }

                continue;
skip:
                data[idst + j] = -INFINITY;
            }
        }
    }
}

template<bool causal, bool swa, bool is_2d>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
    const bool alibi = args.hparams.use_alibi;
    if (alibi) {
        set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
    } else {
        set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
    }
}

template<bool causal, bool swa>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
    const bool is_2d = args.ubatch->is_pos_2d();
    if (is_2d) {
        set_input_kq_mask_impl<causal, swa, true> (args, data);
    } else {
        set_input_kq_mask_impl<causal, swa, false>(args, data);
    }
}

template<bool causal>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
    const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
    if (swa) {
        set_input_kq_mask_impl<causal, true> (args, data);
    } else {
        set_input_kq_mask_impl<causal, false>(args, data);
    }
}
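
// the dispatch chain above folds the runtime flags (causal, swa, is_2d, alibi)
// into template parameters, so each instantiation of the innermost loop is
// specialized at compile time and carries no per-cell branching on these flags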

void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    float * data = (float *) dst->data;

    const int64_t n_kv     = dst->ne[0];
    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch

    GGML_ASSERT(n_tokens%n_stream == 0);

    // n_tps == n_tokens_per_stream
    const int64_t n_tps = n_tokens/n_stream;

    //const int64_t t_start = ggml_time_us();

    const args_set_input_kq_mask args = {
        /*.hparams          =*/ hparams,
        /*.ubatch           =*/ ubatch,
        /*.v_cells          =*/ v_cells,
        /*.seq_to_stream    =*/ seq_to_stream,
        /*.n_swa            =*/ n_swa,
        /*.swa_type         =*/ swa_type,
        /*.n_kv             =*/ n_kv,
        /*.n_stream         =*/ n_stream,
        /*.n_tps            =*/ n_tps,
    };

    if (causal_attn) {
        set_input_kq_mask_impl<true> (args, data);
    } else {
        set_input_kq_mask_impl<false>(args, data);
    }

    //const int64_t t_end = ggml_time_us();

    //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
}

void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
    const auto & cells = v_cells[0];

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    int32_t * data = (int32_t *) dst->data;

    const int32_t n_kv = dst->ne[0];

    for (int h = 0; h < 1; ++h) {
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_kv; ++j) {
                // the position when the cell is empty is irrelevant - it will be masked out later in the attention
1499                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
1500
1501                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
1502            }
1503        }
1504    }
1505}
1506
1507size_t llama_kv_cache::total_size() const {
1508    size_t size = 0;
1509
1510    for (const auto & [_, buf] : ctxs_bufs) {
1511        size += ggml_backend_buffer_get_size(buf.get());
1512    }
1513
1514    return size;
1515}
1516
1517size_t llama_kv_cache::size_k_bytes() const {
1518    size_t size_k_bytes = 0;
1519
1520    for (const auto & layer : layers) {
1521        size_k_bytes += ggml_nbytes(layer.k);
1522    }
1523
1524    return size_k_bytes;
1525}
1526
1527size_t llama_kv_cache::size_v_bytes() const {
1528    size_t size_v_bytes = 0;
1529
1530    for (const auto & layer : layers) {
1531        size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
1532    }
1533
1534    return size_v_bytes;
1535}
1536
ggml_tensor * llama_kv_cache::build_rope_shift(
        const llama_cparams & cparams,
               ggml_context * ctx,
                ggml_tensor * cur,
                ggml_tensor * shift,
                ggml_tensor * factors,
                      float   freq_base,
                      float   freq_scale) const {
    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
    const auto & yarn_attn_factor = cparams.yarn_attn_factor;

    const auto & n_rot     = hparams.n_rot;
    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
                                // @ngxson : this is a workaround
                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
                                // a normal RoPE should work, we just need to use the correct ordering
                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
                                ? LLAMA_ROPE_TYPE_NEOX
                                : hparams.rope_type;

    ggml_tensor * tmp;

    if (ggml_is_quantized(cur->type)) {
        // dequantize to f32 -> RoPE -> quantize back
        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);

        tmp = ggml_rope_ext(ctx, tmp,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

        tmp = ggml_cpy(ctx, tmp, cur);
    } else {
        // we rotate only the first n_rot dimensions
        tmp = ggml_rope_ext_inplace(ctx, cur,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
    }

    return tmp;
}

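// graph input carrying the per-cell position deltas consumed by the K-shift graph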
class llm_graph_input_k_shift : public llm_graph_input_i {
public:
    llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
    virtual ~llm_graph_input_k_shift() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * k_shift = nullptr; // I32 [kv_size*n_stream]

    const llama_kv_cache * kv_self;
};

void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (k_shift) {
        kv_self->set_input_k_shift(k_shift);
    }
}

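// build a graph that re-rotates the cached K data in place after the cell positions
// have been shifted - only the first n_rot dimensions of each head are rotated,
// and for MLA caches the leading NOPE part is skipped via the view offset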
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
    auto * ctx = res->get_ctx();
    auto * gf  = res->get_gf();

    const auto & n_embd_head_k = hparams.n_embd_head_k;
  //const auto & n_embd_head_v = hparams.n_embd_head_v;

    const auto & n_rot = hparams.n_rot;

    const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;

    auto inp = std::make_unique<llm_graph_input_k_shift>(this);

    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
    ggml_set_input(inp->k_shift);

    const auto & cparams = lctx->get_cparams();

    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const int64_t n_head_kv    = hparams.n_head_kv(il);
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);

        ggml_tensor * k =
            ggml_view_3d(ctx, layer.k,
                n_rot, n_head_kv, get_size()*n_stream,
                ggml_row_size(layer.k->type, n_embd_head_k),
                ggml_row_size(layer.k->type, n_embd_k_gqa),
                ggml_row_size(layer.k->type, n_embd_nope));

        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);

        ggml_build_forward_expand(gf, cur);
    }

    res->add_input(std::move(inp));

    return gf;
}

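// state serialization
//
// layout of the stream produced below (sketch - see state_write_meta/state_write_data
// for the authoritative field order):
//
//   uint32_t n_stream
//   per stream:
//     uint32_t cell_count              // 0 -> stream is skipped
//     meta, per cell in the ranges:
//       llama_pos    pos
//       uint32_t     n_seq_id
//       llama_seq_id seq_id[n_seq_id]
//     data:
//       uint32_t v_trans
//       uint32_t n_layer
//       per layer: K type, row size, rows; then V (row-wise, or element-wise when transposed)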
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    io.write(&n_stream, sizeof(n_stream));

    for (uint32_t s = 0; s < n_stream; ++s) {
        cell_ranges_t cr { s, {} };

        uint32_t cell_count = 0;

        const auto & cells = v_cells[s];

        // Count the number of cells with the specified seq_id
        // Find all the ranges of cells with this seq_id (or all, when -1)
        uint32_t cell_range_begin = cells.size();

        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
                ++cell_count;
                if (cell_range_begin == cells.size()) {
                    cell_range_begin = i;
                }
            } else {
                if (cell_range_begin != cells.size()) {
                    cr.data.emplace_back(cell_range_begin, i);
                    cell_range_begin = cells.size();
                }
            }
        }

        if (cell_range_begin != cells.size()) {
            cr.data.emplace_back(cell_range_begin, cells.size());
        }

        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
        uint32_t cell_count_check = 0;
        for (const auto & range : cr.data) {
            cell_count_check += range.second - range.first;
        }
        GGML_ASSERT(cell_count == cell_count_check);

        io.write(&cell_count, sizeof(cell_count));

        // skip empty streams
        if (cell_count == 0) {
            continue;
        }

        state_write_meta(io, cr, seq_id);
        state_write_data(io, cr);
    }
}

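// restore the cache from a stream produced by state_write() - on failure, the target
// sequence (or the whole cache when seq_id == -1) is cleared before throwing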
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

    uint32_t n_stream_cur;
    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
    if (n_stream_cur != n_stream) {
        throw std::runtime_error("n_stream mismatch");
    }

    for (uint32_t s = 0; s < n_stream; ++s) {
        uint32_t cell_count;
        io.read_to(&cell_count, sizeof(cell_count));

        if (cell_count == 0) {
            continue;
        }

        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];

        slot_info sinfo;

        bool res = true;
        res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
        res = res && state_read_data(io, strm, cell_count, sinfo);

        if (!res) {
            if (seq_id == -1) {
                clear(true);
            } else {
                seq_rm(seq_id, -1, -1);
            }
            throw std::runtime_error("failed to restore kv cache");
        }
    }
}

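// write pos/seq_id metadata for each cell in the given ranges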
void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
    const auto & cells = v_cells[cr.strm];

    for (const auto & range : cr.data) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            std::vector<llama_seq_id> seq_ids;

            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
                if (cur == seq_id || seq_id == -1) {
                    if (cells.seq_has(i, cur)) {
                        seq_ids.push_back(cur);
                    }
                }
            }

            const llama_pos pos     = cells.pos_get(i);
            const uint32_t n_seq_id = seq_ids.size();

            io.write(&pos,      sizeof(pos));
            io.write(&n_seq_id, sizeof(n_seq_id));

            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() supports loading it
            //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350

            for (const auto & seq_id : seq_ids) {
                io.write(&seq_id, sizeof(seq_id));
            }
        }
    }
}

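// write the tensor data for the given cell ranges
// - K is stored row-wise: one row per cell
// - V is stored row-wise when !v_trans; otherwise it is transposed in memory, so the
//   cells of each row are written element range by element range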
void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
    const auto & cells = v_cells[cr.strm];

    const uint32_t v_trans = this->v_trans ? 1 : 0;
    const uint32_t n_layer = layers.size();

    io.write(&v_trans, sizeof(v_trans));
    io.write(&n_layer, sizeof(n_layer));

    // Iterate and write all the keys first, each row is a cell
    // Get whole range at a time
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        auto * k = layer.k_stream[cr.strm];

        // Write key type
        const int32_t k_type_i = (int32_t) k->type;
        io.write(&k_type_i, sizeof(k_type_i));

        // Write row size of key
        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        io.write(&k_size_row, sizeof(k_size_row));

        // Read each range of cells of k_size_row length and write out
        for (const auto & range : cr.data) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * k_size_row;
            io.write_tensor(k, range.first * k_size_row, buf_size);
        }
    }

    if (!v_trans) {
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[cr.strm];
            if (!v) {
                continue;
            }

            // Write value type
            const int32_t v_type_i = (int32_t) v->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write row size of value
            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
            io.write(&v_size_row, sizeof(v_size_row));

            // Read each range of cells of v_size_row length and write out
            for (const auto & range : cr.data) {
                const size_t range_size = range.second - range.first;
                const size_t buf_size = range_size * v_size_row;
                io.write_tensor(v, range.first * v_size_row, buf_size);
            }
        }
    } else {
        // When V is transposed, we write the element size instead and gather the element ranges from each row
        const uint32_t kv_size = cells.size();

        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[cr.strm];
            if (!v) {
                continue;
            }

            // Write value type
            const int32_t v_type_i = (int32_t) v->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write element size
            const uint32_t v_size_el = ggml_type_size(v->type);
            io.write(&v_size_el, sizeof(v_size_el));

            // Write GQA embedding size
            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));

            // For each row, we get the element values of each cell
            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                // Read each range of cells of v_size_el length and write out
                for (const auto & range : cr.data) {
                    const size_t range_size = range.second - range.first;
                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                    const size_t buf_size = range_size * v_size_el;
                    io.write_tensor(v, src_offset, buf_size);
                }
            }
        }
    }
}

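// read the per-cell metadata and allocate the destination cells, producing the
// slot_info that state_read_data() uses to place the tensor data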
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
    auto & cells = v_cells[strm];
    auto & head  = v_heads[strm];

    if (dest_seq_id != -1) {
        // single sequence
        seq_rm(dest_seq_id, -1, -1);

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

        ubatch.seq_id_unq[0] = dest_seq_id;

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 1) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            // read the sequence id, but directly discard it - we will use dest_seq_id instead
            {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));
            }

            ubatch.pos[i]      = pos;
            ubatch.n_seq_id[i] = n_seq_id;
            ubatch.seq_id[i]   = &dest_seq_id;
        }

        sinfo = find_slot(ubatch, false);
        if (sinfo.empty()) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        // TODO: we cannot restore llama_kv_cell_ext yet, as apply_ubatch() does not support it
        //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
        apply_ubatch(sinfo, ubatch);

        LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);

        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
        GGML_ASSERT(sinfo.n_stream() == 1);
        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
        for (uint32_t i = 0; i < cell_count; ++i) {
            const uint32_t idx = sinfo.idxs[0][i];
            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
        }
    } else {
        // whole KV cache restore

        if (cell_count > cells.size()) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear(true);

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t  n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cells.pos_set(i, pos);

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
                    return false;
                }

                cells.seq_add(i, seq_id);
            }
        }

        // Create contiguous slot_info for whole cache restore
        sinfo.s0 = strm;
        sinfo.s1 = strm;
        sinfo.resize(1);
        sinfo.strm[0] = strm;
        sinfo.idxs[0].resize(cell_count);
        for (uint32_t i = 0; i < cell_count; ++i) {
            sinfo.idxs[0][i] = i;
        }

        head = 0;
    }

    return true;
}

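// read the tensor data - a contiguous slot_info takes the fast path with a single
// copy per tensor, otherwise the rows/elements are scattered cell by cell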
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
    auto & cells = v_cells[strm];

    uint32_t v_trans;
    uint32_t n_layer;

    io.read_to(&v_trans, sizeof(v_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != layers.size()) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
        return false;
    }

    if (cell_count > cells.size()) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
        return false;
    }

    if (this->v_trans != (bool) v_trans) {
        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
        return false;
    }

    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        auto * k = layer.k_stream[strm];

        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
        const int32_t k_type_i = (int32_t) k->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;
        }

        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;
        }

        if (cell_count) {
            if (sinfo.is_contiguous()) {
                // Fast path: contiguous cells, single contiguous copy
                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
            } else {
                // Slow path: scatter to non-contiguous positions
                const void * src = io.read(cell_count * k_size_row);
                for (uint32_t i = 0; i < cell_count; ++i) {
                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
                    ggml_backend_tensor_set(k, (const char *) src + i * k_size_row, dst_offset, k_size_row);
                }
            }
        }
    }

    if (!this->v_trans) {
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];
            if (!v) {
                continue;
            }

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read row size of value
            uint64_t v_size_row_ref;
            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
            if (v_size_row != v_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                if (sinfo.is_contiguous()) {
                    // Fast path: contiguous cells, single contiguous copy
                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
                } else {
                    // Slow path: scatter to non-contiguous positions
                    const void * src = io.read(cell_count * v_size_row);
                    for (uint32_t i = 0; i < cell_count; ++i) {
                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
                        ggml_backend_tensor_set(v, (const char *) src + i * v_size_row, dst_offset, v_size_row);
                    }
                }
            }
        }
    } else {
        // For each layer, read the values for each cell (transposed)
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];
            if (!v) {
                continue;
            }

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read element size of value
            uint32_t v_size_el_ref;
            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
            const size_t v_size_el = ggml_type_size(v->type);
            if (v_size_el != v_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                return false;
            }

            // Read GQA embedding size
            uint32_t n_embd_v_gqa_ref;
            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
                return false;
            }

            if (cell_count) {
                if (sinfo.is_contiguous()) {
                    // Fast path: contiguous cells
                    const uint32_t h = sinfo.head();
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                    }
                } else {
                    // Slow path: scatter to non-contiguous positions
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                        const void * src = io.read(cell_count * v_size_el);
                        for (uint32_t i = 0; i < cell_count; ++i) {
                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
                            ggml_backend_tensor_set(v, (const char *) src + i * v_size_el, dst_offset, v_size_el);
                        }
                    }
                }
            }
        }
    }

    return true;
}

//
// llama_kv_cache_context
//

llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}

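// full-cache context - covers the entire cache and all streams, used when building
// worst-case graphs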
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();

    const uint32_t n_stream = kv->get_n_stream();

    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
    sinfos.resize(1);
    sinfos[0].s0 = 0;
    sinfos[0].s1 = n_stream - 1;
    sinfos[0].idxs.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        sinfos[0].strm.push_back(s);
        sinfos[0].idxs[s].resize(1, 0);
    }
}

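// update context - applies a K-shift and/or stream copies; reports NO_UPDATE when
// there is nothing to do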
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_context * lctx,
        bool do_shift,
        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
    if (!do_shift && this->sc_info.empty()) {
        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
}

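// decode context - one slot_info per ubatch, consumed in order via next()/apply()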
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_kv_cache::slot_info_vec_t sinfos,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_context::~llama_kv_cache_context() = default;

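// typical consumption of a decode context (illustrative sketch - the actual decode
// loop lives in llama_context; process_ubatch() is a hypothetical stand-in):
//
//   do {
//       mctx->apply();                      // write the current ubatch into the cache
//       process_ubatch(mctx->get_ubatch()); // run the graph for this ubatch
//   } while (mctx->next());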
bool llama_kv_cache_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_cur >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // no ubatches -> this is a KV cache update
    if (ubatches.empty()) {
        kv->update(lctx, do_shift, sc_info);

        return true;
    }

    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
    n_kv = kv->get_n_kv(sinfos[i_cur]);

    return true;
}

llama_memory_status llama_kv_cache_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_cur];
}

uint32_t llama_kv_cache_context::get_n_kv() const {
    return n_kv;
}

ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_k_idxs(ctx, ubatch);
}

ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_pos_bucket(dst, ubatch);
}