1#include "llama-kv-cache-iswa.h"
  2
  3#include "llama-impl.h"
  4#include "llama-batch.h"
  5#include "llama-model.h"
  6
  7#include <algorithm>
  8#include <cassert>
  9
 10//
 11// llama_kv_cache_iswa
 12//
 13
 14llama_kv_cache_iswa::llama_kv_cache_iswa(
 15        const llama_model & model,
 16                ggml_type   type_k,
 17                ggml_type   type_v,
 18                     bool   v_trans,
 19                     bool   offload,
 20                     bool   swa_full,
 21                     bool   unified,
 22                 uint32_t   kv_size,
 23                 uint32_t   n_seq_max,
 24                 uint32_t   n_ubatch,
 25                 uint32_t   n_pad,
 26    const layer_filter_cb & filter,
 27    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
 28
 29    // chain filters
 30    const layer_filter_cb filter_base = [&](int32_t il) {
 31        if (filter && !filter(il)) {
 32            return false;
 33        }
 34
 35        return !model.hparams.is_swa(il);
 36    };
 37
 38    const layer_filter_cb filter_swa  = [&](int32_t il) {
 39        if (filter && !filter(il)) {
 40            return false;
 41        }
 42
 43        return  model.hparams.is_swa(il);
 44    };
 45
 46    const uint32_t size_base = kv_size;
 47
 48    // note: the SWA cache is always padded to 256 for performance
 49    //       https://github.com/ggml-org/llama.cpp/issues/17037
 50    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 51
 52    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
 53    if (swa_full) {
 54        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
 55                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
 56
 57        size_swa = size_base;
 58    }
 59
 60    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 61
 62    kv_base = std::make_unique<llama_kv_cache>(
 63            model, type_k, type_v,
 64            v_trans, offload, unified, size_base, n_seq_max, n_pad,
 65            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 66
 67    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 68
 69    kv_swa = std::make_unique<llama_kv_cache>(
 70            model, type_k, type_v,
 71            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
 72            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 73}
 74
 75void llama_kv_cache_iswa::clear(bool data) {
 76    kv_base->clear(data);
 77    kv_swa ->clear(data);
 78}
 79
 80bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 81    bool res = true;
 82
 83    res = res & kv_base->seq_rm(seq_id, p0, p1);
 84    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
 85
 86    return res;
 87}
 88
 89void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
 90    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 91    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 92}
 93
 94void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
 95    kv_base->seq_keep(seq_id);
 96    kv_swa ->seq_keep(seq_id);
 97}
 98
 99void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
100    kv_base->seq_add(seq_id, p0, p1, shift);
101    kv_swa ->seq_add(seq_id, p0, p1, shift);
102}
103
104void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
105    kv_base->seq_div(seq_id, p0, p1, d);
106    kv_swa ->seq_div(seq_id, p0, p1, d);
107}
108
109llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
110    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
111    return kv_swa->seq_pos_min(seq_id);
112}
113
114llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
115    return kv_swa->seq_pos_max(seq_id);
116}
117
118std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
119    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
120    for (const auto & buft_size : kv_swa->memory_breakdown()) {
121        mb[buft_size.first] += buft_size.second;
122    }
123    return mb;
124}
125
126llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
127    GGML_UNUSED(embd_all);
128
129    // first try simple split
130    do {
131        if (!unified) {
132            // requires equal splits, so we skip the simple split
133            break;
134        }
135
136        balloc.split_reset();
137
138        std::vector<llama_ubatch> ubatches;
139        while (true) {
140            auto ubatch = balloc.split_simple(n_ubatch);
141
142            if (ubatch.n_tokens == 0) {
143                break;
144            }
145
146            ubatches.push_back(std::move(ubatch)); // NOLINT
147        }
148
149        if (balloc.get_n_used() < balloc.get_n_tokens()) {
150            // failed to find a suitable split
151            break;
152        }
153
154        auto sinfos_base = kv_base->prepare(ubatches);
155        if (sinfos_base.empty()) {
156            break;
157        }
158
159        auto sinfos_swa = kv_swa->prepare(ubatches);
160        if (sinfos_swa.empty()) {
161            break;
162        }
163
164        assert(sinfos_base.size() == sinfos_swa.size());
165
166        return std::make_unique<llama_kv_cache_iswa_context>(
167                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
168    } while (false);
169
170    // if it fails, try equal split
171    do {
172        balloc.split_reset();
173
174        std::vector<llama_ubatch> ubatches;
175        while (true) {
176            auto ubatch = balloc.split_equal(n_ubatch, !unified);
177
178            if (ubatch.n_tokens == 0) {
179                break;
180            }
181
182            ubatches.push_back(std::move(ubatch)); // NOLINT
183        }
184
185        if (balloc.get_n_used() < balloc.get_n_tokens()) {
186            // failed to find a suitable split
187            break;
188        }
189
190        auto sinfos_base = kv_base->prepare(ubatches);
191        if (sinfos_base.empty()) {
192            break;
193        }
194
195        auto sinfos_swa = kv_swa->prepare(ubatches);
196        if (sinfos_swa.empty()) {
197            break;
198        }
199
200        assert(sinfos_base.size() == sinfos_swa.size());
201
202        return std::make_unique<llama_kv_cache_iswa_context>(
203                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
204    } while (false);
205
206    // TODO: if we fail again, we should attempt different splitting strategies
207    //       but to do that properly, we first have to refactor the batches to be more flexible
208
209    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
210}
211
212llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
213    return std::make_unique<llama_kv_cache_iswa_context>(this);
214}
215
216llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
217    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
218}
219
220bool llama_kv_cache_iswa::get_can_shift() const {
221    return kv_base->get_can_shift() &&
222           kv_swa->get_can_shift() &&
223           kv_base->get_size() == kv_swa->get_size();
224}
225
226void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
227    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
228        kv_base->state_write(io, seq_id, flags);
229    }
230
231    kv_swa->state_write(io, seq_id, flags);
232}
233
234void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
235    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
236        kv_base->state_read(io, seq_id, flags);
237    }
238
239    kv_swa->state_read(io, seq_id, flags);
240}
241
242llama_kv_cache * llama_kv_cache_iswa::get_base() const {
243    return kv_base.get();
244}
245
246llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
247    return kv_swa.get();
248}
249
250//
251// llama_kv_cache_iswa_context
252//
253
254llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
255
256llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
257        llama_kv_cache_iswa * kv) :
258    ctx_base(kv->get_base()->init_full()),
259    ctx_swa (kv->get_swa ()->init_full()),
260    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
261}
262
263llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
264        llama_kv_cache_iswa * kv,
265        llama_context * lctx,
266        bool optimize) :
267    ctx_base(kv->get_base()->init_update(lctx, optimize)),
268    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
269    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
270}
271
272llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
273        llama_kv_cache_iswa * kv,
274        slot_info_vec_t sinfos_base,
275        slot_info_vec_t sinfos_swa,
276        std::vector<llama_ubatch> ubatches) :
277    ubatches(std::move(ubatches)),
278    // note: here we copy the ubatches. not sure if this is ideal
279    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
280    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
281    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
282}
283
284llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
285
286bool llama_kv_cache_iswa_context::next() {
287    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
288
289    ctx_base->next();
290    ctx_swa ->next();
291
292    if (++i_next >= ubatches.size()) {
293        return false;
294    }
295
296    return true;
297}
298
299bool llama_kv_cache_iswa_context::apply() {
300    assert(!llama_memory_status_is_fail(status));
301
302    bool res = true;
303
304    res = res & ctx_base->apply();
305    res = res & ctx_swa ->apply();
306
307    return res;
308}
309
310llama_memory_status llama_kv_cache_iswa_context::get_status() const {
311    return status;
312}
313
314const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
315    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
316
317    return ubatches[i_next];
318}
319
320const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
321    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
322
323    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
324}
325
326const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
327    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
328
329    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
330}