1#include "llama-kv-cache.h"
2
3#include "llama-impl.h"
4#include "llama-io.h"
5#include "llama-model.h"
6#include "llama-context.h"
7
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <unordered_map>
15
16//
17// llama_kv_cache
18//
19
20llama_kv_cache::llama_kv_cache(
21 const llama_model & model,
22 ggml_type type_k,
23 ggml_type type_v,
24 bool v_trans,
25 bool offload,
26 bool unified,
27 uint32_t kv_size,
28 uint32_t n_seq_max,
29 uint32_t n_pad,
30 uint32_t n_swa,
31 llama_swa_type swa_type,
32 const layer_filter_cb & filter,
33 const layer_reuse_cb & reuse) :
34 model(model), hparams(model.hparams), v_trans(v_trans),
35 n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
36
37 GGML_ASSERT(kv_size % n_pad == 0);
38
39 const uint32_t n_layer_kv = hparams.n_layer_kv();
40
41 // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
42 struct ggml_backend_buft_comparator {
43 bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
44 return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
45 }
46 };
47 std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
48
49 // create a context for each buffer type
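    // note: each context holds only tensor metadata (no_alloc == true); reserve overhead for the
    //       K and V tensors plus their per-stream views: 2*(1 + n_stream) tensors per KV layer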
50 auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
51 auto it = ctx_map.find(buft);
52 if (it == ctx_map.end()) {
53 ggml_init_params params = {
54 /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
55 /*.mem_buffer =*/ NULL,
56 /*.no_alloc =*/ true,
57 };
58
59 ggml_context * ctx = ggml_init(params);
60 if (!ctx) {
61 return nullptr;
62 }
63
64 ctx_map.emplace(buft, ctx);
65
66 return ctx;
67 }
68
69 return it->second.get();
70 };
71
72 GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
73
74 v_heads.resize(n_stream);
75 for (uint32_t s = 0; s < n_stream; ++s) {
76 v_heads[s] = 0;
77 }
78
79 v_cells.resize(n_stream);
80 for (uint32_t s = 0; s < n_stream; ++s) {
81 v_cells[s].resize(kv_size);
82 }
83
84 // by default, all sequence ids are mapped to the 0th stream
85 seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
86
87 if (n_stream > 1) {
88 seq_to_stream.resize(n_stream, 0);
89 for (uint32_t s = 0; s < n_stream; ++s) {
90 seq_to_stream[s] = s;
91 }
92 }
93
94 // [TAG_V_CACHE_VARIABLE]
95 if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
96 LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
97 __func__, hparams.n_embd_v_gqa_max());
98 }
99
100 const bool is_mla = hparams.is_mla();
101
102 for (uint32_t il = 0; il < hparams.n_layer; il++) {
103 if (!hparams.has_kv(il)) {
104 LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
105 continue;
106 }
107
108 if (filter && !filter(il)) {
109 LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
110 continue;
111 }
112
113 // [TAG_V_CACHE_VARIABLE]
114 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
115 const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
116
117 const char * dev_name = "CPU";
118
119 ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
120
121 if (offload) {
122 auto * dev = model.dev_layer(il);
123 buft = ggml_backend_dev_buffer_type(dev);
124
125 dev_name = ggml_backend_dev_name(dev);
126 }
127
128 LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
129
130 ggml_context * ctx = ctx_for_buft(buft);
131 if (!ctx) {
132 throw std::runtime_error("failed to create ggml context for kv cache");
133 }
134
135 const bool has_k = true;
136 const bool has_v = !is_mla;
137
138 ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
139 ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
140
        if (has_k) { ggml_format_name(k, "cache_k_l%d", il); }
        if (has_v) { ggml_format_name(v, "cache_v_l%d", il); }
143
144 std::vector<ggml_tensor *> k_stream;
145 std::vector<ggml_tensor *> v_stream;
146
147 for (uint32_t s = 0; s < n_stream; ++s) {
148 k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
149 v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
150 }
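        // note: the per-stream 2D views above are what seq_cp() stream copies and state save/load operate on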
151
152 map_layer_ids[il] = layers.size();
153
154 layers.push_back({ il, k, v, k_stream, v_stream, });
155 }
156
157 if (reuse) {
158 LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
159
160 for (uint32_t il = 0; il < hparams.n_layer; il++) {
161 const int32_t il_reuse = reuse(il);
162
163 if (il_reuse < 0) {
164 LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
165 continue;
166 }
167
168 if (filter && !filter(il)) {
169 LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
170 continue;
171 }
172
173 GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
174
175 map_layer_ids[il] = map_layer_ids[il_reuse];
176
177 LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
178 }
179 }
180
181 // allocate tensors and initialize the buffers to avoid NaNs in the padding
182 for (auto & [buft, ctx] : ctx_map) {
183 ggml_backend_buffer_t buf;
184 if (model.hparams.no_alloc) {
185 buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
186 for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
187 t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
188 }
189 } else {
190 buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
191 }
192 if (!buf) {
193 throw std::runtime_error("failed to allocate buffer for kv cache");
194 }
195
196 LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
197
198 ggml_backend_buffer_clear(buf, 0);
199 ctxs_bufs.emplace_back(std::move(ctx), buf);
200 }
201
202 {
203 const size_t memory_size_k = size_k_bytes();
204 const size_t memory_size_v = size_v_bytes();
205
206 LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
207 (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
208 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
209 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
210 }
211
212 const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
213 debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
214}
215
216void llama_kv_cache::clear(bool data) {
217 for (uint32_t s = 0; s < n_stream; ++s) {
218 v_cells[s].reset();
219 v_heads[s] = 0;
220 }
221
222 if (data) {
223 for (auto & [_, buf] : ctxs_bufs) {
224 ggml_backend_buffer_clear(buf.get(), 0);
225 }
226 }
227}
228
229bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
230 GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
231
232 if (p0 < 0) {
233 p0 = 0;
234 }
235
236 if (p1 < 0) {
237 p1 = std::numeric_limits<llama_pos>::max();
238 }
239
240 if (seq_id >= 0) {
241 auto & cells = v_cells[seq_to_stream[seq_id]];
242 auto & head = v_heads[seq_to_stream[seq_id]];
243
244 uint32_t new_head = cells.size();
245
246 for (uint32_t i = 0; i < cells.size(); ++i) {
247 if (!cells.pos_in(i, p0, p1)) {
248 continue;
249 }
250
251 if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
252 if (new_head == cells.size()) {
253 new_head = i;
254 }
255 }
256 }
257
258 // If we freed up a slot, set head to it so searching can start there.
259 if (new_head != cells.size() && new_head < head) {
260 head = new_head;
261 }
262 } else {
263 // match any sequence
264 for (uint32_t s = 0; s < n_stream; ++s) {
265 auto & cells = v_cells[s];
266 auto & head = v_heads[s];
267
268 uint32_t new_head = cells.size();
269
270 for (uint32_t i = 0; i < cells.size(); ++i) {
271 if (!cells.pos_in(i, p0, p1)) {
272 continue;
273 }
274
275 cells.rm(i);
276
277 if (new_head == cells.size()) {
278 new_head = i;
279 }
280 }
281
282 // If we freed up a slot, set head to it so searching can start there.
283 if (new_head != cells.size() && new_head < head) {
284 head = new_head;
285 }
286 }
287 }
288
289 return true;
290}
291
292void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
293 GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
294 GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
295
296 const auto s0 = seq_to_stream[seq_id_src];
297 const auto s1 = seq_to_stream[seq_id_dst];
298
299 if (s0 == s1) {
300 // since both sequences are in the same stream, no data copy is necessary
301 // we just have to update the cells meta data
302
303 auto & cells = v_cells[s0];
304
305 if (seq_id_src == seq_id_dst) {
306 return;
307 }
308
309 if (p0 < 0) {
310 p0 = 0;
311 }
312
313 if (p1 < 0) {
314 p1 = std::numeric_limits<llama_pos>::max();
315 }
316
317 for (uint32_t i = 0; i < cells.size(); ++i) {
318 if (!cells.pos_in(i, p0, p1)) {
319 continue;
320 }
321
322 if (cells.seq_has(i, seq_id_src)) {
323 cells.seq_add(i, seq_id_dst);
324 }
325 }
326
327 return;
328 }
329
    // cross-stream sequence copies require copying the actual buffer data
331
332 bool is_full = true;
333
334 if (p0 > 0 && p0 + 1 < (int) get_size()) {
335 is_full = false;
336 }
337
338 if (p1 > 0 && p1 + 1 < (int) get_size()) {
339 is_full = false;
340 }
341
342 GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
343
344 // enqueue the copy operation - the buffer copy will be performed during the next update
345 sc_info.ssrc.push_back(s0);
346 sc_info.sdst.push_back(s1);
347
348 v_cells[s1].reset();
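    // replicate the cell metadata into the destination stream: set the base (unshifted) position first,
    // then re-apply the shift so that any pending K-shift state is carried over as well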
349 for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
350 if (v_cells[s0].seq_has(i, seq_id_src)) {
351 llama_pos pos = v_cells[s0].pos_get(i);
352 llama_pos shift = v_cells[s0].get_shift(i);
353
354 llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
355
356 if (shift != 0) {
357 pos -= shift;
358 assert(pos >= 0);
359 }
360
361 v_cells[s1].pos_set(i, pos);
362 v_cells[s1].seq_add(i, seq_id_dst);
363
364 if (shift != 0) {
365 v_cells[s1].pos_add(i, shift);
366 }
367
368 v_cells[s1].ext_set(i, ext);
369 }
370 }
371
372 v_heads[s1] = v_heads[s0];
373
374 //for (uint32_t s = 0; s < n_stream; ++s) {
375 // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
376 //}
377}
378
379void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
380 GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
381
382 auto & cells = v_cells[seq_to_stream[seq_id]];
383 auto & head = v_heads[seq_to_stream[seq_id]];
384
385 uint32_t new_head = cells.size();
386
387 for (uint32_t i = 0; i < cells.size(); ++i) {
388 if (cells.seq_keep(i, seq_id)) {
389 if (new_head == cells.size()) {
390 new_head = i;
391 }
392 }
393 }
394
395 // If we freed up a slot, set head to it so searching can start there.
396 if (new_head != cells.size() && new_head < head) {
397 head = new_head;
398 }
399}
400
401void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
402 GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
403 GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
404
405 auto & cells = v_cells[seq_to_stream[seq_id]];
406 auto & head = v_heads[seq_to_stream[seq_id]];
407
408 if (shift == 0) {
409 return;
410 }
411
412 uint32_t new_head = cells.size();
413
414 if (p0 < 0) {
415 p0 = 0;
416 }
417
418 if (p1 < 0) {
419 p1 = std::numeric_limits<llama_pos>::max();
420 }
421
422 // If there is no range then return early to avoid looping over all cells.
423 if (p0 == p1) {
424 return;
425 }
426
427 for (uint32_t i = 0; i < cells.size(); ++i) {
428 if (!cells.pos_in(i, p0, p1)) {
429 continue;
430 }
431
432 if (cells.seq_has(i, seq_id)) {
433 if (cells.pos_add(i, shift)) {
434 if (new_head == cells.size()) {
435 new_head = i;
436 }
437 }
438 }
439 }
440
441 // If we freed up a slot, set head to it so searching can start there.
442 // Otherwise we just start the next search from the beginning.
443 head = new_head != cells.size() ? new_head : 0;
444}
445
446void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
447 GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
448 GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
449
450 auto & cells = v_cells[seq_to_stream[seq_id]];
451
452 if (d == 1) {
453 return;
454 }
455
456 if (p0 < 0) {
457 p0 = 0;
458 }
459
460 if (p1 < 0) {
461 p1 = std::numeric_limits<llama_pos>::max();
462 }
463
464 // If there is no range then return early to avoid looping over the cache.
465 if (p0 == p1) {
466 return;
467 }
468
469 for (uint32_t i = 0; i < cells.size(); ++i) {
470 if (!cells.pos_in(i, p0, p1)) {
471 continue;
472 }
473
474 if (cells.seq_has(i, seq_id)) {
475 cells.pos_div(i, d);
476 }
477 }
478}
479
480llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
481 GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
482
483 const auto & cells = v_cells[seq_to_stream[seq_id]];
484
485 return cells.seq_pos_min(seq_id);
486}
487
488llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
489 GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
490
491 const auto & cells = v_cells[seq_to_stream[seq_id]];
492
493 return cells.seq_pos_max(seq_id);
494}
495
496std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
497 std::map<ggml_backend_buffer_type_t, size_t> ret;
498 for (const auto & [ctx, buf] : ctxs_bufs) {
499 ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());
500
501 if (hparams.no_alloc) {
502 GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
503 ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
504 } else {
505 // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
506 ret[buft] += ggml_backend_buffer_get_size(buf.get());
507 }
508 }
509
510 return ret;
511}
512
513llama_memory_context_ptr llama_kv_cache::init_batch(
514 llama_batch_allocr & balloc,
515 uint32_t n_ubatch,
516 bool embd_all) {
517 GGML_UNUSED(embd_all);
518
519 do {
520 balloc.split_reset();
521
522 std::vector<llama_ubatch> ubatches;
523 while (true) {
524 auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
525
526 if (ubatch.n_tokens == 0) {
527 break;
528 }
529
530 ubatches.push_back(std::move(ubatch)); // NOLINT
531 }
532
533 if (balloc.get_n_used() < balloc.get_n_tokens()) {
534 // failed to find a suitable split
535 break;
536 }
537
538 auto sinfos = prepare(ubatches);
539 if (sinfos.empty()) {
540 break;
541 }
542
543 return std::make_unique<llama_kv_cache_context>(
544 this, std::move(sinfos), std::move(ubatches));
545 } while (false);
546
547 return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
548}
549
550llama_memory_context_ptr llama_kv_cache::init_full() {
551 return std::make_unique<llama_kv_cache_context>(this);
552}
553
554llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
555 GGML_UNUSED(optimize);
556
557 bool do_shift = get_has_shift();
558
559 return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
560}
561
562llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
563 llama_kv_cache::slot_info_vec_t res;
564
565 struct state_t {
566 slot_info sinfo; // slot info for the ubatch
567
568 std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
569
570 std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
571 };
572
573 // remember the old state of the cells so we can restore it in the end
574 std::vector<state_t> states;
575
576 bool success = true;
577
578 for (const auto & ubatch : ubatches) {
579 // only find a suitable slot for the ubatch. don't modify the cells yet
580 const auto sinfo_new = find_slot(ubatch, false);
581 if (sinfo_new.empty()) {
582 success = false;
583 break;
584 }
585
        // remember the position that we found
587 res.push_back(sinfo_new);
588
589 // store the old state of the cells in the recovery stack
590 {
591 state_t state = { sinfo_new, v_heads, {} };
592
593 for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
594 auto & cells = v_cells[sinfo_new.strm[s]];
595
596 state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
597 }
598
599 states.push_back(std::move(state));
600 }
601
602 // now emplace the ubatch
603 apply_ubatch(sinfo_new, ubatch);
604 }
605
606 GGML_ASSERT(!states.empty() || !success);
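    // note: the cells are always restored below, even on success - the ubatches will be applied
    //       for real later on, one by one, via apply_ubatch() during decoding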
607
608 // iterate backwards and restore the cells to their original state
609 for (auto it = states.rbegin(); it != states.rend(); ++it) {
610 const auto & sinfo = it->sinfo;
611
612 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
613 auto & cells = v_cells[sinfo.strm[s]];
614 auto & head = v_heads[sinfo.strm[s]];
615
616 cells.set(sinfo.idxs[s], it->v_cells[s]);
617 head = it->v_heads_old[s];
618 }
619 }
620
621 if (!success) {
622 return {};
623 }
624
625 return res;
626}
627
628bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
629 bool updated = false;
630
631 auto * sched = lctx->get_sched();
632
633 if (!sc_info.empty()) {
634 assert(n_stream > 1 && "stream copy should never happen with a single stream");
635
636 llama_synchronize(lctx);
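        // note: the synchronization above guarantees that no in-flight graph is still using the KV buffers
        //       while the stream data is being copied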
637
638 const size_t n_copy = sc_info.ssrc.size();
639
640 for (size_t i = 0; i < n_copy; ++i) {
641 const auto ssrc = sc_info.ssrc[i];
642 const auto sdst = sc_info.sdst[i];
643
644 assert(ssrc < n_stream);
645 assert(sdst < n_stream);
646
647 LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
648
649 assert(ssrc != sdst);
650
651 for (uint32_t il = 0; il < layers.size(); ++il) {
652 const auto & layer = layers[il];
653
654 ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
655
656 if (layer.v_stream[ssrc]) {
657 ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
658 }
659 }
660 }
661 }
662
663 if (do_shift) {
664 if (!get_can_shift()) {
665 GGML_ABORT("The current KV cache / model configuration does not support K-shift");
666 }
667
668 LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
669
670 // apply K-shift if needed
671 if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
672 ggml_backend_sched_reset(sched);
673
674 auto * res = lctx->get_gf_res_reserve();
675
676 res->reset();
677
678 auto * gf = build_graph_shift(res, lctx);
679 if (!ggml_backend_sched_alloc_graph(sched, gf)) {
680 LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
681 return updated;
682 }
683
684 res->set_inputs(nullptr);
685
686 if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
687 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
688 return updated;
689 }
690
691 updated = true;
692 }
693
694 for (uint32_t s = 0; s < n_stream; ++s) {
695 auto & cells = v_cells[s];
696
697 cells.reset_shift();
698 }
699 }
700
701 return updated;
702}
703
704llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
705
706 if (debug > 0) {
707 for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
708 const auto seq_id = ubatch.seq_id_unq[s];
709 const auto stream_id = seq_to_stream[seq_id];
710 const auto & cells = v_cells[stream_id];
711 const uint32_t head_cur = v_heads[stream_id];
712
713 LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
714 __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
715
716 if ((debug == 2 && n_swa > 0) || debug > 2) {
717 std::string ss;
718 for (uint32_t i = 0; i < cells.size(); ++i) {
719 if (cells.is_empty(i)) {
720 ss += '.';
721 } else {
722 assert(cells.seq_count(i) >= 1);
723
724 if (cells.seq_count(i) == 1) {
725 ss += std::to_string(cells.seq_get(i));
726 } else {
727 ss += 'M';
728 }
729 }
730 if (i%256 == 255) {
731 ss += " *";
732 ss += '\n';
733 }
734 }
735 LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
736 }
737
738 if ((debug == 2 && n_swa > 0) || debug > 2) {
739 std::string ss;
740 for (uint32_t i = 0; i < cells.size(); ++i) {
741 std::string cur;
742 if (cells.is_empty(i)) {
743 cur = '.';
744 } else {
745 cur = std::to_string(cells.pos_get(i));
746 }
747 const int n = cur.size();
748 for (int j = 0; j < 5 - n; ++j) {
749 cur += ' ';
750 }
751 ss += cur;
752 if (i%256 == 255) {
753 ss += " *";
754 }
755 if (i%64 == 63) {
756 ss += '\n';
757 }
758 }
759 LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
760 }
761
762 for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
763 if (cells.seq_pos_min(s) < 0) {
764 continue;
765 }
766
767 LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
768 }
769 }
770 }
771
772 uint32_t n_tokens = ubatch.n_tokens;
773 uint32_t n_seqs = 1;
774
775 if (n_stream > 1) {
776 GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
777
778 n_seqs = ubatch.n_seqs_unq;
779 n_tokens = n_tokens / n_seqs;
780 }
781
782 slot_info res = {
783 /*.s0 =*/ LLAMA_MAX_SEQ,
784 /*.s1 =*/ 0,
785 /*.strm =*/ { },
786 /*.idxs =*/ { },
787 };
788
789 res.resize(n_seqs);
790
791 for (uint32_t s = 0; s < n_seqs; ++s) {
792 const auto seq_id = ubatch.seq_id_unq[s];
793
794 if (n_stream > 1) {
795 GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
796 GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
797 }
798
799 res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
800 res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
801
802 res.strm[s] = seq_to_stream[seq_id];
803 res.idxs[s].reserve(n_tokens);
804
805 const auto & cells = v_cells[seq_to_stream[seq_id]];
806
807 uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
808
809 // if we have enough unused cells before the current head ->
810 // better to start searching from the beginning of the cache, hoping to fill it
811 if (head_cur > cells.get_used() + 2*n_tokens) {
812 head_cur = 0;
813 }
814
815 if (n_tokens > cells.size()) {
816 LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
817 return { };
818 }
819
820 uint32_t n_tested = 0;
821
        // for contiguous slots, we test that all tokens in the ubatch fit, starting from the current head
        // for non-contiguous slots, we test the tokens one by one
824 const uint32_t n_test = cont ? n_tokens : 1;
825
826 while (true) {
827 if (head_cur + n_test > cells.size()) {
828 n_tested += cells.size() - head_cur;
829 head_cur = 0;
830 continue;
831 }
832
833 for (uint32_t i = 0; i < n_test; i++) {
834 const auto idx = head_cur;
835
836 head_cur++;
837 n_tested++;
838
839 //const llama_pos pos = ubatch.pos[i];
840 //const llama_seq_id seq_id = ubatch.seq_id[i][0];
841
842 // can we use this cell? either:
843 // - the cell is empty
844 // - the cell is occupied only by one sequence:
845 // - (disabled) mask causally, if the sequence is the same as the one we are inserting
846 // - mask SWA, using current max pos for that sequence in the cache
847 // always insert in the cell with minimum pos
848 bool can_use = cells.is_empty(idx);
849
850 if (!can_use && cells.seq_count(idx) == 1) {
851 const llama_pos pos_cell = cells.pos_get(idx);
852
853 // (disabled) causal mask
854 // note: it's better to purge any "future" tokens beforehand
855 //if (cells.seq_has(idx, seq_id)) {
856 // can_use = pos_cell >= pos;
857 //}
858
859 if (!can_use) {
860 const llama_seq_id seq_id_cell = cells.seq_get(idx);
861
862 // SWA mask
863 if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
864 can_use = true;
865 }
866 }
867 }
868
869 if (can_use) {
870 res.idxs[s].push_back(idx);
871 } else {
872 if (cont) {
873 break;
874 }
875 }
876 }
877
878 if (res.idxs[s].size() == n_tokens) {
879 break;
880 }
881
882 if (cont) {
883 res.idxs[s].clear();
884 }
885
886 if (n_tested >= cells.size()) {
887 //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
888 return { };
889 }
890 }
891
892 // we didn't find a suitable slot - return empty result
893 if (res.idxs[s].size() < n_tokens) {
894 return { };
895 }
896 }
897
898 assert(res.s1 >= res.s0);
899
900 return res;
901}
902
903void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA caches, this will always be empty
    llama_pos seq_pos_max_rm[LLAMA_MAX_SEQ];
907 for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
908 seq_pos_max_rm[s] = -1;
909 }
910
911 assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
912
913 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
914 for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
915 const uint32_t i = s*sinfo.size() + ii;
916
917 auto & cells = v_cells[sinfo.strm[s]];
918
919 const auto idx = sinfo.idxs[s][ii];
920
921 if (!cells.is_empty(idx)) {
922 assert(cells.seq_count(idx) == 1);
923
924 const llama_seq_id seq_id = cells.seq_get(idx);
925 const llama_pos pos = cells.pos_get(idx);
926
927 seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
928
929 cells.rm(idx);
930 }
931
932 cells.pos_set(idx, ubatch.pos[i]);
933
934 if (ubatch.is_pos_2d()) {
935 llama_kv_cell_ext ext {
936 /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
937 /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
938 };
939 cells.ext_set(idx, ext);
940 }
941
942 for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
943 cells.seq_add(idx, ubatch.seq_id[i][s]);
944 }
945 }
946 }
947
948 // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
949 // will be present in the cache. so we have to purge any position which is less than those we would overwrite
950 // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
951 for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
952 if (seq_pos_max_rm[s] == -1) {
953 continue;
954 }
955
956 GGML_ASSERT(s < seq_to_stream.size());
957
958 auto & cells = v_cells[seq_to_stream[s]];
959
960 if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
961 LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
962 __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
963
964 seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
965 }
966 }
967
968 // move the head at the end of the slot
969 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
970 auto & head = v_heads[sinfo.strm[s]];
971
972 head = sinfo.idxs[s].back() + 1;
973 }
974}
975
976bool llama_kv_cache::get_can_shift() const {
977 // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot.
978 if (model.arch == LLM_ARCH_STEP35) {
979 return false;
980 }
981 return true;
982}
983
984uint32_t llama_kv_cache::get_size() const {
985 const auto & cells = v_cells[seq_to_stream[0]];
986
987 return cells.size();
988}
989
990uint32_t llama_kv_cache::get_n_stream() const {
991 return n_stream;
992}
993
994bool llama_kv_cache::get_has_shift() const {
995 bool result = false;
996
997 for (uint32_t s = 0; s < n_stream; ++s) {
998 result |= v_cells[s].get_has_shift();
999 }
1000
1001 return result;
1002}
1003
1004uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
1005 uint32_t result = 0;
1006
1007 // pad the n_kv value so that the graph remains constant across batches and can be reused
1008 // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
1009 const uint32_t n_pad_cur = std::max(n_pad, 256u);
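    // example: with used_max_p1 = 300 and n_pad_cur = 256, n_kv = GGML_PAD(300, 256) = 512 (clamped to the cache size)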
1010
1011 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1012 const auto & cells = v_cells[sinfo.strm[s]];
1013
1014 result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
1015 }
1016
1017 return result;
1018}
1019
1020ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
1021 const int32_t ikv = map_layer_ids.at(il);
1022
1023 auto * k = layers[ikv].k;
1024
1025 const uint64_t kv_size = get_size();
1026 const uint64_t n_embd_k_gqa = k->ne[0];
1027
1028 assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
1029
1030 const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1031
1032 return ggml_view_4d(ctx, k,
1033 hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
1034 ggml_row_size(k->type, hparams.n_embd_head_k),
1035 ggml_row_size(k->type, n_embd_k_gqa),
1036 ggml_row_size(k->type, n_embd_k_gqa*kv_size),
1037 ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
1038}
1039
1040ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
1041 const int32_t ikv = map_layer_ids.at(il);
1042
1043 auto * v = layers[ikv].v;
1044
1045 const uint64_t kv_size = get_size();
1046 const uint64_t n_embd_v_gqa = v->ne[0];
1047
1048 // [TAG_V_CACHE_VARIABLE]
1049 assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
1050
1051 const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1052
1053 if (!v_trans) {
1054 // note: v->nb[1] <= v->nb[2]
1055 return ggml_view_4d(ctx, v,
1056 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1057 ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1058 ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1059 ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1060 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
1061 }
1062
1063 // note: v->nb[1] > v->nb[2]
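    // in the transposed layout each cache row stores one embedding component across all kv_size cells,
    // so the view swaps the n_kv and embedding dimensions relative to the non-transposed case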
1064 return ggml_view_4d(ctx, v,
1065 n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1066 ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1067 ggml_row_size(v->type, kv_size), // v->nb[2]
1068 ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1069 ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
1070}
1071
1072ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
1073 GGML_UNUSED(sinfo);
1074
1075 const int32_t ikv = map_layer_ids.at(il);
1076
1077 ggml_tensor * k = layers[ikv].k;
1078
1079 const int64_t n_embd_head = k_cur->ne[0];
1080 const int64_t n_head = k_cur->ne[1];
1081 const int64_t n_tokens = k_cur->ne[2];
1082
1083 const int64_t n_embd_gqa = n_embd_head*n_head;
1084
1085 // we can merge dims 0 and 1
1086 // TODO: add ggml helper function for this?
1087 GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
1088
1089 k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
1090
1091 const int64_t n_stream = k->ne[2];
1092
1093 if (n_stream > 1) {
1094 const int64_t kv_size = get_size();
1095
1096 assert(n_embd_gqa == k->ne[0]);
1097 assert(kv_size == k->ne[1]);
1098
1099 // merge the buffer across all streams because the idxs are global
1100 k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
1101 }
1102
1103 // store the current K values into the cache
1104 return ggml_set_rows(ctx, k, k_cur, k_idxs);
1105}
1106
1107ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
1108 GGML_UNUSED(sinfo);
1109
1110 const int32_t ikv = map_layer_ids.at(il);
1111
1112 auto * v = layers[ikv].v;
1113
1114 const int64_t n_embd_head = v_cur->ne[0];
1115 const int64_t n_head = v_cur->ne[1];
1116 const int64_t n_tokens = v_cur->ne[2];
1117
1118 const int64_t n_embd_gqa = n_embd_head*n_head;
1119
1120 // we can merge dims 0 and 1
1121 GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
1122
1123 const int64_t n_stream = v->ne[2];
1124
1125 // take this branch when FA is enabled (the V cache is not transposed)
1126 if (!v_trans) {
1127 v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
1128
1129 if (n_stream > 1) {
1130 const int64_t kv_size = get_size();
1131
1132 assert(n_embd_gqa == v->ne[0]);
1133 assert(kv_size == v->ne[1]);
1134
1135 // merge the buffer across all streams because the idxs are global
1136 v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
1137 }
1138
1139 return ggml_set_rows(ctx, v, v_cur, v_idxs);
1140 }
1141
1142 if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
1143 // we can merge dims 0, 1 and 2
1144 v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
1145 } else {
1146 // otherwise -> make a copy to get contiguous data
1147 v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
1148 }
1149
1150 // [TAG_V_CACHE_VARIABLE]
1151 if (n_embd_gqa < v->ne[0]) {
1152 v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
1153 }
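    // note: when the V cache is transposed, all layers share the same padded row size (n_embd_v_gqa_max),
    //       so values of layers with a smaller V embedding are padded with zeros before being scattered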
1154
1155 // in this branch the v_idxs are constructed in such a way that each row is a single head element
1156 ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
1157
1158 v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
1159
1160 return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
1161}
1162
1163ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1164 const uint32_t n_tokens = ubatch.n_tokens;
1165
1166 ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1167
1168 ggml_set_input(k_idxs);
1169
1170 return k_idxs;
1171}
1172
1173ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1174 const uint32_t n_tokens = ubatch.n_tokens;
1175
1176 ggml_tensor * v_idxs;
1177
1178 if (!v_trans) {
1179 v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1180 } else {
1181 v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
1182 }
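    // note: for the transposed V cache, set_rows() scatters individual elements rather than whole rows,
    //       so each token needs one destination index per embedding component (see set_input_v_idxs())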
1183
1184 ggml_set_input(v_idxs);
1185
1186 return v_idxs;
1187}
1188
1189void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
1190 const uint32_t n_tokens = ubatch->n_tokens;
1191 GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1192
1193 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1194 int64_t * data = (int64_t *) dst->data;
1195
1196 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1197 const int64_t offs = sinfo.strm[s]*get_size();
1198
1199 for (uint32_t i = 0; i < sinfo.size(); ++i) {
1200 data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1201 }
1202 }
1203}
1204
1205void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
1206 const uint32_t n_tokens = ubatch->n_tokens;
1207 GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1208
1209 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1210 int64_t * data = (int64_t *) dst->data;
1211
1212 if (!v_trans) {
1213 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1214 const int64_t offs = sinfo.strm[s]*get_size();
1215
1216 for (uint32_t i = 0; i < sinfo.size(); ++i) {
1217 data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1218 }
1219 }
1220 } else {
1221 // note: the V cache is transposed when not using flash attention
1222 const int64_t kv_size = get_size();
1223
1224 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
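        // destination element index = stream offset + component j * kv_size + cell index,
        // matching the element-wise layout used by cpy_v() for the transposed V cache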
1225
1226 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1227 const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
1228
1229 for (uint32_t i = 0; i < sinfo.size(); ++i) {
1230 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1231 data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
1232 }
1233 }
1234 }
1235 }
1236}
1237
1238void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
1239 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1240
1241 int32_t * data = (int32_t *) dst->data;
1242
1243 for (uint32_t s = 0; s < n_stream; ++s) {
1244 const auto & cells = v_cells[s];
1245
1246 for (uint32_t i = 0; i < cells.size(); ++i) {
1247 data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1248 }
1249 }
1250}
1251
1252struct args_set_input_kq_mask {
1253 const llama_hparams & hparams;
1254 const llama_ubatch * ubatch;
1255
1256 const std::vector<llama_kv_cells> & v_cells;
1257 const std::vector<uint32_t> & seq_to_stream;
1258
1259 uint32_t n_swa;
1260 llama_swa_type swa_type;
1261
1262 int64_t n_kv;
1263 int64_t n_stream;
1264 int64_t n_tps;
1265};
1266
1267template<bool causal, bool swa, bool is_2d, bool alibi>
1268static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1269 //const auto & hparams = args.hparams;
1270 const auto & ubatch = args.ubatch;
1271
1272 const auto & v_cells = args.v_cells;
1273 const auto & seq_to_stream = args.seq_to_stream;
1274
1275 const uint32_t n_swa = args.n_swa;
1276 const llama_swa_type swa_type = args.swa_type;
1277
1278 const int64_t n_kv = args.n_kv;
1279 const int64_t n_stream = args.n_stream;
1280 const int64_t n_tps = args.n_tps;
1281
1282 // the min position in the batch for each sequence
1283 llama_pos seq_pos_min[LLAMA_MAX_SEQ];
1284 std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
1285
1286 for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
1287 const llama_seq_id seq_id = ubatch->seq_id[i][0];
1288
1289 seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
1290 }
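    // seq_pos_min is used below to decide which mask cells could differ between tokens of the same
    // sequence (due to causal/SWA masking) and therefore have to be revisited for each token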
1291
1292 for (uint32_t s = 0; s < n_stream; ++s) {
        // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
1294 std::unordered_map<llama_seq_id, uint32_t> seq_srct;
1295 std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
1296
1297 for (uint32_t ii = 0; ii < n_tps; ++ii) {
1298 const uint32_t i = s*n_tps + ii;
1299
1300 const llama_seq_id seq_id = ubatch->seq_id[i][0];
1301
1302 const auto & cells = v_cells.at(seq_to_stream[seq_id]);
1303
1304 llama_pos p0 = -1;
1305 const llama_pos p1 = ubatch->pos[i];
1306
1307 // for M-RoPE
1308 const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
1309 const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
1310
1311 const uint64_t idst = n_kv*i;
1312
1313 // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
1314 // the only cells that could change are the ones that are with similar positions as the
1315 // ones in the batch (i.e. due to causal masking, SWA, etc.)
1316 // keep track of those cells and shortcut the loop to save time
1317 // note: this optimization is not compatible with Alibi position encoding
1318 // ref: https://github.com/ggml-org/llama.cpp/pull/18842
1319 bool prev = false;
1320
1321 auto & idxs = seq_idxs[seq_id];
1322
1323 if (!alibi) {
1324 if (seq_srct.find(seq_id) != seq_srct.end()) {
1325 const uint32_t srct = seq_srct[seq_id];
1326
1327 const uint64_t idst_prev = n_kv*srct;
1328
1329 std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
1330
1331 prev = true;
1332 } else {
1333 idxs.clear();
1334 idxs.reserve(ubatch->n_tokens + n_swa + 32);
1335
1336 seq_srct[seq_id] = i;
1337 }
1338 }
1339
1340 for (uint32_t jj = 0; jj < n_kv; ++jj) {
1341 uint32_t j = jj;
1342
                // we already have a mask for this sequence -> only revisit the recorded cells
1344 if (!alibi) {
1345 if (prev) {
1346 if (jj >= idxs.size()) {
1347 break;
1348 }
1349
1350 j = idxs[jj];
1351 }
1352 }
1353
1354 if (cells.is_empty(j)) {
1355 goto skip;
1356 }
1357
1358 // mask the token if not the same sequence
1359 if (!cells.seq_has(j, seq_id)) {
1360 goto skip;
1361 }
1362
1363 p0 = cells.pos_get(j);
1364
1365 if (!alibi) {
1366 if (!prev) {
1367 // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
1368 if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
1369 idxs.push_back(j);
1370 }
1371 }
1372 }
1373
1374 if (causal) {
1375 // mask future tokens
1376 if (p0 > p1) {
1377 goto skip;
1378 }
1379
1380 // M-RoPE causal mask
1381 if (is_2d) {
1382 if (p0 == p1) {
1383 const auto & p0_ext = cells.ext_get(j);
1384
1385 if (p0_ext.is_2d_gt(p1_x, p1_y)) {
1386 goto skip;
1387 }
1388 }
1389 }
1390 }
1391
1392 // apply SWA if any
1393 if (swa) {
1394 if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
1395 goto skip;
1396 }
1397 }
1398
1399 if (alibi) {
1400 data[idst + j] = -std::abs(p0 - p1);
1401 } else {
1402 data[idst + j] = 0.0f;
1403 }
1404
1405 continue;
1406skip:
1407 data[idst + j] = -INFINITY;
1408 }
1409 }
1410 }
1411}
1412
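// the wrappers below select the template specialization at runtime, peeling off one boolean per level,
// so that the hot loop in set_input_kq_mask_impl() is compiled without per-cell branches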
1413template<bool causal, bool swa, bool is_2d>
1414static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1415 const bool alibi = args.hparams.use_alibi;
1416 if (alibi) {
1417 set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
1418 } else {
1419 set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
1420 }
1421}
1422
1423template<bool causal, bool swa>
1424static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1425 const bool is_2d = args.ubatch->is_pos_2d();
1426 if (is_2d) {
1427 set_input_kq_mask_impl<causal, swa, true> (args, data);
1428 } else {
1429 set_input_kq_mask_impl<causal, swa, false>(args, data);
1430 }
1431}
1432
1433template<bool causal>
1434static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1435 const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
1436 if (swa) {
1437 set_input_kq_mask_impl<causal, true> (args, data);
1438 } else {
1439 set_input_kq_mask_impl<causal, false>(args, data);
1440 }
1441}
1442
1443void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1444 const uint32_t n_tokens = ubatch->n_tokens;
1445
1446 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1447 float * data = (float *) dst->data;
1448
1449 const int64_t n_kv = dst->ne[0];
1450 const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1451
1452 GGML_ASSERT(n_tokens%n_stream == 0);
1453
1454 // n_tps == n_tokens_per_stream
1455 const int64_t n_tps = n_tokens/n_stream;
1456
1457 //const int64_t t_start = ggml_time_us();
1458
1459 const args_set_input_kq_mask args = {
1460 /*.hparams =*/ hparams,
1461 /*.ubatch =*/ ubatch,
1462 /*.v_cells =*/ v_cells,
1463 /*.seq_to_stream =*/ seq_to_stream,
1464 /*.n_swa =*/ n_swa,
1465 /*.swa_type =*/ swa_type,
1466 /*.n_kv =*/ n_kv,
1467 /*.n_stream =*/ n_stream,
1468 /*.n_tps =*/ n_tps,
1469 };
1470
1471 if (causal_attn) {
1472 set_input_kq_mask_impl<true> (args, data);
1473 } else {
1474 set_input_kq_mask_impl<false>(args, data);
1475 }
1476
1477 //const int64_t t_end = ggml_time_us();
1478
1479 //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
1480}
1481
1482void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1483 const int64_t n_tokens = ubatch->n_tokens;
1484
1485 GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
1486 const auto & cells = v_cells[0];
1487
1488 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1489 GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
1490
1491 int32_t * data = (int32_t *) dst->data;
1492
1493 const int32_t n_kv = dst->ne[0];
1494
1495 for (int h = 0; h < 1; ++h) {
1496 for (int i = 0; i < n_tokens; ++i) {
1497 for (int j = 0; j < n_kv; ++j) {
                // the position when the cell is empty is irrelevant - it will be masked out later in the attention
1499 const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
1500
1501 data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
1502 }
1503 }
1504 }
1505}
1506
1507size_t llama_kv_cache::total_size() const {
1508 size_t size = 0;
1509
1510 for (const auto & [_, buf] : ctxs_bufs) {
1511 size += ggml_backend_buffer_get_size(buf.get());
1512 }
1513
1514 return size;
1515}
1516
1517size_t llama_kv_cache::size_k_bytes() const {
1518 size_t size_k_bytes = 0;
1519
1520 for (const auto & layer : layers) {
1521 size_k_bytes += ggml_nbytes(layer.k);
1522 }
1523
1524 return size_k_bytes;
1525}
1526
1527size_t llama_kv_cache::size_v_bytes() const {
1528 size_t size_v_bytes = 0;
1529
1530 for (const auto & layer : layers) {
1531 size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
1532 }
1533
1534 return size_v_bytes;
1535}
1536
1537ggml_tensor * llama_kv_cache::build_rope_shift(
1538 const llama_cparams & cparams,
1539 ggml_context * ctx,
1540 ggml_tensor * cur,
1541 ggml_tensor * shift,
1542 ggml_tensor * factors,
1543 float freq_base,
1544 float freq_scale) const {
1545 const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
1546
1547 const auto & yarn_ext_factor = cparams.yarn_ext_factor;
1548 const auto & yarn_beta_fast = cparams.yarn_beta_fast;
1549 const auto & yarn_beta_slow = cparams.yarn_beta_slow;
1550 const auto & yarn_attn_factor = cparams.yarn_attn_factor;
1551
1552 const auto & n_rot = hparams.n_rot;
1553 const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
1554 // @ngxson : this is a workaround
1555 // for M-RoPE, we want to rotate the whole vector when doing KV shift
1556 // a normal RoPE should work, we just need to use the correct ordering
1557 // ref: https://github.com/ggml-org/llama.cpp/pull/13870
1558 ? LLAMA_ROPE_TYPE_NEOX
1559 : hparams.rope_type;
1560
1561 ggml_tensor * tmp;
1562
1563 if (ggml_is_quantized(cur->type)) {
1564 // dequantize to f32 -> RoPE -> quantize back
1565 tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
1566
1567 tmp = ggml_rope_ext(ctx, tmp,
1568 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1569 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
1570
1571 tmp = ggml_cpy(ctx, tmp, cur);
1572 } else {
1573 // we rotate only the first n_rot dimensions
1574 tmp = ggml_rope_ext_inplace(ctx, cur,
1575 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1576 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
1577 }
1578
1579 return tmp;
1580}
1581
1582class llm_graph_input_k_shift : public llm_graph_input_i {
1583public:
1584 llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
1585 virtual ~llm_graph_input_k_shift() = default;
1586
1587 void set_input(const llama_ubatch * ubatch) override;
1588
1589 ggml_tensor * k_shift; // I32 [kv_size*n_stream]
1590
1591 const llama_kv_cache * kv_self;
1592};
1593
1594void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
1595 GGML_UNUSED(ubatch);
1596
1597 if (k_shift) {
1598 kv_self->set_input_k_shift(k_shift);
1599 }
1600}
1601
1602ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1603 auto * ctx = res->get_ctx();
1604 auto * gf = res->get_gf();
1605
1606 const auto & n_embd_head_k = hparams.n_embd_head_k;
1607 //const auto & n_embd_head_v = hparams.n_embd_head_v;
1608
1609 const auto & n_rot = hparams.n_rot;
1610
1611 const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
1612
1613 auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1614
1615 inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
1616 ggml_set_input(inp->k_shift);
1617
1618 const auto & cparams = lctx->get_cparams();
1619
1620 for (const auto & layer : layers) {
1621 const uint32_t il = layer.il;
1622
1623 const int64_t n_head_kv = hparams.n_head_kv(il);
1624 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1625
1626 const float freq_base_l = model.get_rope_freq_base (cparams, il);
1627 const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
1628
1629 ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
1630
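        // view only n_rot dims of each K head; the n_embd_nope offset skips the non-rotated (NoPE) part,
        // which for MLA models (n_lora_kv > 0) is stored before the RoPE'd dims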
1631 ggml_tensor * k =
1632 ggml_view_3d(ctx, layer.k,
1633 n_rot, n_head_kv, get_size()*n_stream,
1634 ggml_row_size(layer.k->type, n_embd_head_k),
1635 ggml_row_size(layer.k->type, n_embd_k_gqa),
1636 ggml_row_size(layer.k->type, n_embd_nope));
1637
1638 ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
1639
1640 ggml_build_forward_expand(gf, cur);
1641 }
1642
1643 res->add_input(std::move(inp));
1644
1645 return gf;
1646}
1647
1648void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
1649 GGML_UNUSED(flags);
1650
1651 io.write(&n_stream, sizeof(n_stream));
1652
1653 for (uint32_t s = 0; s < n_stream; ++s) {
1654 cell_ranges_t cr { s, {} };
1655
1656 uint32_t cell_count = 0;
1657
1658 const auto & cells = v_cells[s];
1659
1660 // Count the number of cells with the specified seq_id
1661 // Find all the ranges of cells with this seq id (or all, when -1)
1662 uint32_t cell_range_begin = cells.size();
1663
1664 for (uint32_t i = 0; i < cells.size(); ++i) {
1665 if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1666 ++cell_count;
1667 if (cell_range_begin == cells.size()) {
1668 cell_range_begin = i;
1669 }
1670 } else {
1671 if (cell_range_begin != cells.size()) {
1672 cr.data.emplace_back(cell_range_begin, i);
1673 cell_range_begin = cells.size();
1674 }
1675 }
1676 }
1677
1678 if (cell_range_begin != cells.size()) {
1679 cr.data.emplace_back(cell_range_begin, cells.size());
1680 }
1681
1682 // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1683 uint32_t cell_count_check = 0;
1684 for (const auto & range : cr.data) {
1685 cell_count_check += range.second - range.first;
1686 }
1687 GGML_ASSERT(cell_count == cell_count_check);
1688
1689 io.write(&cell_count, sizeof(cell_count));
1690
1691 // skip empty streams
1692 if (cell_count == 0) {
1693 continue;
1694 }
1695
1696 state_write_meta(io, cr, seq_id);
1697 state_write_data(io, cr);
1698 }
1699}
1700
1701void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1702 GGML_UNUSED(flags);
1703
1704 GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
1705
1706 uint32_t n_stream_cur;
1707 io.read_to(&n_stream_cur, sizeof(n_stream_cur));
1708 if (n_stream_cur != n_stream) {
1709 throw std::runtime_error("n_stream mismatch");
1710 }
1711
1712 for (uint32_t s = 0; s < n_stream; ++s) {
1713 uint32_t cell_count;
1714 io.read_to(&cell_count, sizeof(cell_count));
1715
1716 if (cell_count == 0) {
1717 continue;
1718 }
1719
1720 const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
1721
1722 slot_info sinfo;
1723
1724 bool res = true;
1725 res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
1726 res = res && state_read_data(io, strm, cell_count, sinfo);
1727
1728 if (!res) {
1729 if (seq_id == -1) {
1730 clear(true);
1731 } else {
1732 seq_rm(seq_id, -1, -1);
1733 }
1734 throw std::runtime_error("failed to restore kv cache");
1735 }
1736 }
1737}
1738
1739void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
1740 const auto & cells = v_cells[cr.strm];
1741
1742 for (const auto & range : cr.data) {
1743 for (uint32_t i = range.first; i < range.second; ++i) {
1744 std::vector<llama_seq_id> seq_ids;
1745
1746 for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1747 if (cur == seq_id || seq_id == -1) {
1748 if (cells.seq_has(i, cur)) {
1749 seq_ids.push_back(cur);
1750 }
1751 }
1752 }
1753
1754 const llama_pos pos = cells.pos_get(i);
1755 const uint32_t n_seq_id = seq_ids.size();
1756
1757 io.write(&pos, sizeof(pos));
1758 io.write(&n_seq_id, sizeof(n_seq_id));
1759
1760 // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
1761 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
1762
1763 for (const auto & seq_id : seq_ids) {
1764 io.write(&seq_id, sizeof(seq_id));
1765 }
1766 }
1767 }
1768}
1769
1770void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
1771 const auto & cells = v_cells[cr.strm];
1772
1773 const uint32_t v_trans = this->v_trans ? 1 : 0;
1774 const uint32_t n_layer = layers.size();
1775
1776 io.write(&v_trans, sizeof(v_trans));
1777 io.write(&n_layer, sizeof(n_layer));
1778
1779 // Iterate and write all the keys first, each row is a cell
1780 // Get whole range at a time
1781 for (const auto & layer : layers) {
1782 const uint32_t il = layer.il;
1783
1784 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1785
1786 auto * k = layer.k_stream[cr.strm];
1787
1788 // Write key type
1789 const int32_t k_type_i = (int32_t) k->type;
1790 io.write(&k_type_i, sizeof(k_type_i));
1791
1792 // Write row size of key
1793 const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
1794 io.write(&k_size_row, sizeof(k_size_row));
1795
1796 // Read each range of cells of k_size length and write out
1797 for (const auto & range : cr.data) {
1798 const size_t range_size = range.second - range.first;
1799 const size_t buf_size = range_size * k_size_row;
1800 io.write_tensor(k, range.first * k_size_row, buf_size);
1801 }
1802 }
1803
1804 if (!v_trans) {
1805 for (const auto & layer : layers) {
1806 const uint32_t il = layer.il;
1807
1808 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1809
1810 auto * v = layer.v_stream[cr.strm];
1811 if (!v) {
1812 continue;
1813 }
1814
1815 // Write value type
1816 const int32_t v_type_i = (int32_t) v->type;
1817 io.write(&v_type_i, sizeof(v_type_i));
1818
1819 // Write row size of value
1820 const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
1821 io.write(&v_size_row, sizeof(v_size_row));
1822
1823 // Read each range of cells of v_size length and write out
1824 for (const auto & range : cr.data) {
1825 const size_t range_size = range.second - range.first;
1826 const size_t buf_size = range_size * v_size_row;
1827 io.write_tensor(v, range.first * v_size_row, buf_size);
1828 }
1829 }
1830 } else {
1831 // When v is transposed, we also need the element size and get the element ranges from each row
1832 const uint32_t kv_size = cells.size();
1833
1834 for (const auto & layer : layers) {
1835 const uint32_t il = layer.il;
1836
1837 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1838
1839 auto * v = layer.v_stream[cr.strm];
1840 if (!v) {
1841 continue;
1842 }
1843
1844 // Write value type
1845 const int32_t v_type_i = (int32_t) v->type;
1846 io.write(&v_type_i, sizeof(v_type_i));
1847
1848 // Write element size
1849 const uint32_t v_size_el = ggml_type_size(v->type);
1850 io.write(&v_size_el, sizeof(v_size_el));
1851
1852 // Write GQA embedding size
1853 io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1854
1855 // For each row, we get the element values of each cell
1856 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1857 // Read each range of cells of v_size_el length and write out
1858 for (const auto & range : cr.data) {
1859 const size_t range_size = range.second - range.first;
1860 const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1861 const size_t buf_size = range_size * v_size_el;
1862 io.write_tensor(v, src_offset, buf_size);
1863 }
1864 }
1865 }
1866 }
1867}
1868
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
    auto & cells = v_cells[strm];
    auto & head = v_heads[strm];

    if (dest_seq_id != -1) {
        // single sequence
        seq_rm(dest_seq_id, -1, -1);

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

        ubatch.seq_id_unq[0] = dest_seq_id;

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 1) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            // read the sequence id, but directly discard it - we will use dest_seq_id instead
            {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));
            }

            ubatch.pos[i] = pos;
            ubatch.n_seq_id[i] = n_seq_id;
            ubatch.seq_id[i] = &dest_seq_id;
        }

        sinfo = find_slot(ubatch, false);
        if (sinfo.empty()) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
        apply_ubatch(sinfo, ubatch);

        LLAMA_LOG_DEBUG("%s: cell_count = %u, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);

        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
        GGML_ASSERT(sinfo.n_stream() == 1);
        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
        for (uint32_t i = 0; i < cell_count; ++i) {
            const uint32_t idx = sinfo.idxs[0][i];
            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
        }
    } else {
        // whole KV cache restore

        if (cell_count > cells.size()) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear(true);

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cells.pos_set(i, pos);

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
                    return false;
                }

                cells.seq_add(i, seq_id);
            }
        }

        // Create contiguous slot_info for whole cache restore
        sinfo.s0 = strm;
        sinfo.s1 = strm;
        sinfo.resize(1);
        sinfo.strm[0] = strm;
        sinfo.idxs[0].resize(cell_count);
        for (uint32_t i = 0; i < cell_count; ++i) {
            sinfo.idxs[0][i] = i;
        }

        head = 0;
    }

    return true;
}

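// restore the K/V tensor data for the cells described by sinfo
// the serialized tensor types and row/element sizes are validated against the current cache layout before copying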
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
    auto & cells = v_cells[strm];

    uint32_t v_trans;
    uint32_t n_layer;

    io.read_to(&v_trans, sizeof(v_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != layers.size()) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
        return false;
    }

    if (cell_count > cells.size()) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
        return false;
    }

    if (this->v_trans != (bool) v_trans) {
        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
        return false;
    }

    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        auto * k = layer.k_stream[strm];

        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
        const int32_t k_type_i = (int32_t) k->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;
        }

        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;
        }

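        // cells created for a whole-cache restore are contiguous by construction, while cells
        // allocated via find_slot() during a single-sequence restore may be scattered,
        // hence the per-row fallback below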
        if (cell_count) {
            if (sinfo.is_contiguous()) {
                // Fast path: contiguous cells, single memcpy
                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
            } else {
                // Slow path: scatter to non-contiguous positions
                const void * src = io.read(cell_count * k_size_row);
                for (uint32_t i = 0; i < cell_count; ++i) {
                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
                    ggml_backend_tensor_set(k, (const char *) src + i * k_size_row, dst_offset, k_size_row);
                }
            }
        }
    }

    if (!this->v_trans) {
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];
            if (!v) {
                continue;
            }

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read row size of value
            uint64_t v_size_row_ref;
            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
            if (v_size_row != v_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                if (sinfo.is_contiguous()) {
                    // Fast path: contiguous cells, single memcpy
                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
                } else {
                    // Slow path: scatter to non-contiguous positions
                    const void * src = io.read(cell_count * v_size_row);
                    for (uint32_t i = 0; i < cell_count; ++i) {
                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
                        ggml_backend_tensor_set(v, (const char *) src + i * v_size_row, dst_offset, v_size_row);
                    }
                }
            }
        }
    } else {
        // For each layer, read the values for each cell (transposed)
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];
            if (!v) {
                continue;
            }

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read element size of value
            uint32_t v_size_el_ref;
            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
            const size_t v_size_el = ggml_type_size(v->type);
            if (v_size_el != v_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                return false;
            }

            // Read GQA embedding size
            uint32_t n_embd_v_gqa_ref;
            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
                return false;
            }

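            // in the transposed layout, each of the n_embd_v_gqa rows spans cells.size() contiguous
            // elements, so cell c of row j is written at element offset c + j * cells.size()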
            if (cell_count) {
                if (sinfo.is_contiguous()) {
                    // Fast path: contiguous cells
                    const uint32_t h = sinfo.head();
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                    }
                } else {
                    // Slow path: scatter to non-contiguous positions
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                        const void * src = io.read(cell_count * v_size_el);
                        for (uint32_t i = 0; i < cell_count; ++i) {
                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
                            ggml_backend_tensor_set(v, (const char *) src + i * v_size_el, dst_offset, v_size_el);
                        }
                    }
                }
            }
        }
    }

    return true;
}

//
// llama_kv_cache_context
//

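// status-only context - carries just a memory status without an underlying cache (e.g. to report an error)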
llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}

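// full-cache context - creates a dummy slot info spanning one cell in each stream so that a graph
// covering the entire cache can be built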
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();

    const uint32_t n_stream = kv->get_n_stream();

    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
    sinfos.resize(1);
    sinfos[0].s0 = 0;
    sinfos[0].s1 = n_stream - 1;
    sinfos[0].idxs.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        sinfos[0].strm.push_back(s);
        sinfos[0].idxs[s].resize(1, 0);
    }
}

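// update context - carries a pending K-shift and/or stream copies to be applied via apply();
// the status is LLAMA_MEMORY_STATUS_NO_UPDATE when there is nothing to do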
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_context * lctx,
        bool do_shift,
        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
    if (!do_shift && this->sc_info.empty()) {
        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
}

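// batch-processing context - holds the ubatches to decode together with their pre-computed slot infos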
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_kv_cache::slot_info_vec_t sinfos,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_context::~llama_kv_cache_context() = default;

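// advance to the next ubatch; returns false once all ubatches have been consumed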
bool llama_kv_cache_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_cur >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // no ubatches -> this is a KV cache update
    if (ubatches.empty()) {
        kv->update(lctx, do_shift, sc_info);

        return true;
    }

    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
    n_kv = kv->get_n_kv(sinfos[i_cur]);

    return true;
}

llama_memory_status llama_kv_cache_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_cur];
}

uint32_t llama_kv_cache_context::get_n_kv() const {
    return n_kv;
}

ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_k_idxs(ctx, ubatch);
}

ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_pos_bucket(dst, ubatch);
}