1#include "llama-kv-cache-iswa.h"
2
3#include "llama-impl.h"
4#include "llama-batch.h"
5#include "llama-model.h"
6
7#include <algorithm>
8#include <cassert>
9
10//
11// llama_kv_cache_iswa
12//
13
14llama_kv_cache_iswa::llama_kv_cache_iswa(
15 const llama_model & model,
16 ggml_type type_k,
17 ggml_type type_v,
18 bool v_trans,
19 bool offload,
20 bool swa_full,
21 bool unified,
22 uint32_t kv_size,
23 uint32_t n_seq_max,
24 uint32_t n_ubatch,
25 uint32_t n_pad,
26 const layer_filter_cb & filter,
27 const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
28
29 // chain filters
30 const layer_filter_cb filter_base = [&](int32_t il) {
31 if (filter && !filter(il)) {
32 return false;
33 }
34
35 return !model.hparams.is_swa(il);
36 };
37
38 const layer_filter_cb filter_swa = [&](int32_t il) {
39 if (filter && !filter(il)) {
40 return false;
41 }
42
43 return model.hparams.is_swa(il);
44 };
45
46 const uint32_t size_base = kv_size;
47
48 // note: the SWA cache is always padded to 256 for performance
49 // https://github.com/ggml-org/llama.cpp/issues/17037
50 uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
51
52 // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
53 if (swa_full) {
54 LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
55 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
56
57 size_swa = size_base;
58 }
59
60 LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
61
62 kv_base = std::make_unique<llama_kv_cache>(
63 model, type_k, type_v,
64 v_trans, offload, unified, size_base, n_seq_max, n_pad,
65 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
66
67 LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
68
69 kv_swa = std::make_unique<llama_kv_cache>(
70 model, type_k, type_v,
71 v_trans, offload, unified, size_swa, n_seq_max, n_pad,
72 hparams.n_swa, hparams.swa_type, filter_swa, reuse);
73}
74
75void llama_kv_cache_iswa::clear(bool data) {
76 kv_base->clear(data);
77 kv_swa ->clear(data);
78}
79
80bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
81 bool res = true;
82
83 res = res & kv_base->seq_rm(seq_id, p0, p1);
84 res = res & kv_swa ->seq_rm(seq_id, p0, p1);
85
86 return res;
87}
88
89void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
90 kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
91 kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
92}
93
94void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
95 kv_base->seq_keep(seq_id);
96 kv_swa ->seq_keep(seq_id);
97}
98
99void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
100 kv_base->seq_add(seq_id, p0, p1, shift);
101 kv_swa ->seq_add(seq_id, p0, p1, shift);
102}
103
104void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
105 kv_base->seq_div(seq_id, p0, p1, d);
106 kv_swa ->seq_div(seq_id, p0, p1, d);
107}
108
109llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
110 // the base cache is a superset of the SWA cache, so we can just check the SWA cache
111 return kv_swa->seq_pos_min(seq_id);
112}
113
114llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
115 return kv_swa->seq_pos_max(seq_id);
116}
117
118std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
119 std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
120 for (const auto & buft_size : kv_swa->memory_breakdown()) {
121 mb[buft_size.first] += buft_size.second;
122 }
123 return mb;
124}
125
126llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
127 GGML_UNUSED(embd_all);
128
129 // first try simple split
130 do {
131 if (!unified) {
132 // requires equal splits, so we skip the simple split
133 break;
134 }
135
136 balloc.split_reset();
137
138 std::vector<llama_ubatch> ubatches;
139 while (true) {
140 auto ubatch = balloc.split_simple(n_ubatch);
141
142 if (ubatch.n_tokens == 0) {
143 break;
144 }
145
146 ubatches.push_back(std::move(ubatch)); // NOLINT
147 }
148
149 if (balloc.get_n_used() < balloc.get_n_tokens()) {
150 // failed to find a suitable split
151 break;
152 }
153
154 auto sinfos_base = kv_base->prepare(ubatches);
155 if (sinfos_base.empty()) {
156 break;
157 }
158
159 auto sinfos_swa = kv_swa->prepare(ubatches);
160 if (sinfos_swa.empty()) {
161 break;
162 }
163
164 assert(sinfos_base.size() == sinfos_swa.size());
165
166 return std::make_unique<llama_kv_cache_iswa_context>(
167 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
168 } while (false);
169
170 // if it fails, try equal split
171 do {
172 balloc.split_reset();
173
174 std::vector<llama_ubatch> ubatches;
175 while (true) {
176 auto ubatch = balloc.split_equal(n_ubatch, !unified);
177
178 if (ubatch.n_tokens == 0) {
179 break;
180 }
181
182 ubatches.push_back(std::move(ubatch)); // NOLINT
183 }
184
185 if (balloc.get_n_used() < balloc.get_n_tokens()) {
186 // failed to find a suitable split
187 break;
188 }
189
190 auto sinfos_base = kv_base->prepare(ubatches);
191 if (sinfos_base.empty()) {
192 break;
193 }
194
195 auto sinfos_swa = kv_swa->prepare(ubatches);
196 if (sinfos_swa.empty()) {
197 break;
198 }
199
200 assert(sinfos_base.size() == sinfos_swa.size());
201
202 return std::make_unique<llama_kv_cache_iswa_context>(
203 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
204 } while (false);
205
206 // TODO: if we fail again, we should attempt different splitting strategies
207 // but to do that properly, we first have to refactor the batches to be more flexible
208
209 return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
210}
211
212llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
213 return std::make_unique<llama_kv_cache_iswa_context>(this);
214}
215
216llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
217 return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
218}
219
220bool llama_kv_cache_iswa::get_can_shift() const {
221 return kv_base->get_can_shift() &&
222 kv_swa->get_can_shift() &&
223 kv_base->get_size() == kv_swa->get_size();
224}
225
226void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
227 if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
228 kv_base->state_write(io, seq_id, flags);
229 }
230
231 kv_swa->state_write(io, seq_id, flags);
232}
233
234void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
235 if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
236 kv_base->state_read(io, seq_id, flags);
237 }
238
239 kv_swa->state_read(io, seq_id, flags);
240}
241
242llama_kv_cache * llama_kv_cache_iswa::get_base() const {
243 return kv_base.get();
244}
245
246llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
247 return kv_swa.get();
248}
249
250//
251// llama_kv_cache_iswa_context
252//
253
254llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
255
256llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
257 llama_kv_cache_iswa * kv) :
258 ctx_base(kv->get_base()->init_full()),
259 ctx_swa (kv->get_swa ()->init_full()),
260 status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
261}
262
263llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
264 llama_kv_cache_iswa * kv,
265 llama_context * lctx,
266 bool optimize) :
267 ctx_base(kv->get_base()->init_update(lctx, optimize)),
268 ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
269 status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
270}
271
272llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
273 llama_kv_cache_iswa * kv,
274 slot_info_vec_t sinfos_base,
275 slot_info_vec_t sinfos_swa,
276 std::vector<llama_ubatch> ubatches) :
277 ubatches(std::move(ubatches)),
278 // note: here we copy the ubatches. not sure if this is ideal
279 ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
280 ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
281 status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
282}
283
284llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
285
286bool llama_kv_cache_iswa_context::next() {
287 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
288
289 ctx_base->next();
290 ctx_swa ->next();
291
292 if (++i_next >= ubatches.size()) {
293 return false;
294 }
295
296 return true;
297}
298
299bool llama_kv_cache_iswa_context::apply() {
300 assert(!llama_memory_status_is_fail(status));
301
302 bool res = true;
303
304 res = res & ctx_base->apply();
305 res = res & ctx_swa ->apply();
306
307 return res;
308}
309
310llama_memory_status llama_kv_cache_iswa_context::get_status() const {
311 return status;
312}
313
314const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
315 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
316
317 return ubatches[i_next];
318}
319
320const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
321 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
322
323 return static_cast<const llama_kv_cache_context *>(ctx_base.get());
324}
325
326const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
327 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
328
329 return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
330}