1#pragma once
  2
  3#include "llama.h"
  4#include "llama-cparams.h"
  5
  6#include <bitset>
  7#include <cassert>
  8#include <cstring>
  9#include <map>
 10#include <set>
 11#include <vector>
 12
 13struct llama_kv_cell_ext {
 14    // 2D spatial positions, typically used for M-RoPE
 15    llama_pos x = 0;
 16    llama_pos y = 0;
 17
 18    // return true if the current 2D spatial position is greater than other
 19    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
 20        return (y > oy) || (y == oy && x > ox);
 21    }
 22
 23    void reset() {
 24        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
 25
 26        memset(this, 0, sizeof(*this));
 27    }
 28};
 29
 30// meta information about KV cells that can be part of multiple sequences at the same time
 31// TODO: add unit tests
 32class llama_kv_cells {
 33public:
 34    void reset() {
 35        for (uint32_t i = 0; i < pos.size(); ++i) {
 36            pos[i]   = -1;
 37            ext[i].reset();
 38            shift[i] =  0;
 39            seq[i].reset();
 40        }
 41
 42        has_shift = false;
 43
 44        used.clear();
 45
 46        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
 47            seq_pos[s].clear();
 48        }
 49    }
 50
 51    void reset_shift() {
 52        has_shift = false;
 53
 54        for (uint32_t i = 0; i < shift.size(); ++i) {
 55            shift[i] = 0;
 56        }
 57    }
 58
 59    uint32_t size() const {
 60        return pos.size();
 61    }
 62
 63    void resize(uint32_t n) {
 64        pos.resize(n);
 65        ext.resize(n);
 66        shift.resize(n);
 67        seq.resize(n);
 68
 69        reset();
 70    }
 71
 72    bool is_empty(uint32_t i) const {
 73        assert(i < pos.size());
 74        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
 75
 76        return pos[i] == -1;
 77    }
 78
 79    uint32_t get_used() const {
 80        return used.size();
 81    }
 82
 83    // the index of the first cell that is used
 84    // return 0 if no cells are used
 85    uint32_t used_min() const {
 86        return used.empty() ? 0 : *used.begin();
 87    }
 88
 89    // the index of the last cell that is used + 1
 90    // return 0 if no cells are used
 91    uint32_t used_max_p1() const {
 92        return used.empty() ? 0 : *used.rbegin() + 1;
 93    }
 94
 95    bool get_has_shift() const {
 96        return has_shift;
 97    }
 98
 99    // move cell isrc to idst (used during defrag)
100    //void mv(uint32_t isrc, uint32_t idst) {
101    //    assert(isrc < pos.size());
102    //    assert(idst < pos.size());
103
104    //    assert(pos[idst] == -1);
105    //    assert(pos[isrc] != -1);
106
107    //    pos  [idst] = pos  [isrc];
108    //    shift[idst] = shift[isrc];
109    //    seq  [idst] = seq  [isrc];
110
111    //    pos  [isrc] = -1;
112    //    shift[isrc] =  0;
113    //    seq  [isrc].reset();
114
115    //    used.erase (isrc);
116    //    used.insert(idst);
117    //}
118
119    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
120    llama_kv_cells cp(uint32_t i, uint32_t n) const {
121        assert(i + n <= pos.size());
122
123        llama_kv_cells res;
124
125        res.resize(n);
126
127        for (uint32_t j = 0; j < n; ++j) {
128            const auto idx = i + j;
129
130            res.pos[j] = pos[idx];
131            res.ext[j] = ext[idx];
132            res.seq[j] = seq[idx];
133
134            assert(shift[idx] == 0);
135        }
136
137        return res;
138    }
139
140    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
141    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
142        llama_kv_cells res;
143
144        res.resize(idxs.size());
145
146        for (uint32_t j = 0; j < idxs.size(); ++j) {
147            const auto idx = idxs[j];
148
149            res.pos[j] = pos[idx];
150            res.ext[j] = ext[idx];
151            res.seq[j] = seq[idx];
152
153            assert(shift[idx] == 0);
154        }
155
156        return res;
157    }
158
159    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
160    void set(uint32_t i, const llama_kv_cells & other) {
161        assert(i + other.pos.size() <= pos.size());
162
163        for (uint32_t j = 0; j < other.pos.size(); ++j) {
164            const auto idx = i + j;
165
166            if (pos[idx] == -1 && other.pos[j] != -1) {
167                used.insert(i + j);
168            }
169
170            if (pos[idx] != -1 && other.pos[j] == -1) {
171                used.erase(i + j);
172            }
173
174            if (pos[idx] != -1) {
175                seq_pos_rm(i + j);
176            }
177
178            pos[idx] = other.pos[j];
179            ext[idx] = other.ext[j];
180            seq[idx] = other.seq[j];
181
182            if (pos[idx] != -1) {
183                seq_pos_add(i + j);
184            }
185
186            assert(shift[idx] == 0);
187        }
188    }
189
190    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
191    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
192        assert(idxs.size() == other.pos.size());
193
194        for (uint32_t j = 0; j < other.pos.size(); ++j) {
195            const auto idx = idxs[j];
196
197            if (pos[idx] == -1 && other.pos[j] != -1) {
198                used.insert(idx);
199            }
200
201            if (pos[idx] != -1 && other.pos[j] == -1) {
202                used.erase(idx);
203            }
204
205            if (pos[idx] != -1) {
206                seq_pos_rm(idx);
207            }
208
209            pos[idx] = other.pos[j];
210            ext[idx] = other.ext[j];
211            seq[idx] = other.seq[j];
212
213            if (pos[idx] != -1) {
214                seq_pos_add(idx);
215            }
216
217            assert(shift[idx] == 0);
218        }
219    }
220
221    // clear a non-empty cell
222    void rm(uint32_t i) {
223        assert(i < pos.size());
224        assert(pos[i] != -1);
225
226        seq_pos_rm(i);
227        seq[i].reset();
228
229        pos[i] = -1;
230        ext[i].reset();
231        shift[i] = 0;
232
233        used.erase(i);
234    }
235
236    // note: call only if the cell has seq_id
237    // return true if the cell becomes empty
238    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
239        assert(i < pos.size());
240        assert(seq[i].test(seq_id));
241        assert(pos[i] != -1);
242        assert(seq_id >= 0);
243
244        seq[i].reset(seq_id);
245        seq_pos_dec(seq_id, pos[i]);
246
247        if (seq[i].none()) {
248            pos[i] = -1;
249            ext[i].reset();
250            shift[i] = 0;
251
252            used.erase(i);
253
254            return true;
255        }
256
257        return false;
258    }
259
260    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
261    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
262        assert(i < pos.size());
263
264        if (seq[i].test(seq_id)) {
265            seq_pos_rm(i);
266            seq[i].reset();
267
268            seq[i].set(seq_id);
269            seq_pos_inc(seq_id, pos[i]);
270
271            return false;
272        }
273
274        if (seq[i].any()) {
275            seq_pos_rm(i);
276            seq[i].reset();
277
278            pos[i] = -1;
279            ext[i].reset();
280            shift[i] = 0;
281
282            used.erase(i);
283
284            return true;
285        }
286
287        assert(pos[i] == -1);
288
289        return false;
290    }
291
292    // number of different sequences in the cell
293    int seq_count(uint32_t i) const {
294        assert(i < pos.size());
295        assert(pos[i] != -1);
296
297        return seq[i].count();
298    }
299
300    // check if the cell contains seq_id
301    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
302        assert(i < pos.size());
303        assert(seq_id >= 0);
304
305        return seq[i].test(seq_id);
306    }
307
308    // note: call only if the cell is not empty and the seq_id is not in the cell
309    void seq_add(uint32_t i, llama_seq_id seq_id) {
310        assert(i < pos.size());
311        assert(pos[i] != -1);
312        assert(!seq[i].test(seq_id));
313
314        seq[i].set(seq_id);
315        seq_pos_inc(seq_id, pos[i]);
316    }
317
318    // return the sequence id of this cell
319    // note: call only for cells with exactly one sequence
320    llama_seq_id seq_get(uint32_t i) const {
321        assert(seq[i].count() == 1);
322
323        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
324            if (seq[i].test(s)) {
325                return s;
326            }
327        }
328
329        return -1;
330    }
331
332    // the minimum position of sequence seq_id currently present in any of the cells
333    // return -1 if the sequence is not present
334    llama_pos seq_pos_min(llama_seq_id seq_id) const {
335        assert(seq_id >= 0);
336        assert(seq_id < LLAMA_MAX_SEQ);
337
338        if (seq_pos[seq_id].empty()) {
339            return -1;
340        }
341
342        assert(seq_pos[seq_id].begin()->second > 0);
343
344        return seq_pos[seq_id].begin()->first;
345    }
346
347    // the maximum position of sequence seq_id currently present in any of the cells
348    // return -1 if the sequence is not present
349    llama_pos seq_pos_max(llama_seq_id seq_id) const {
350        assert(seq_id >= 0);
351        assert(seq_id < LLAMA_MAX_SEQ);
352
353        if (seq_pos[seq_id].empty()) {
354            return -1;
355        }
356
357        assert(seq_pos[seq_id].rbegin()->second > 0);
358
359        return seq_pos[seq_id].rbegin()->first;
360    }
361
362    // note: call only if the cell is not empty
363    llama_pos pos_get(uint32_t i) const {
364        assert(i < pos.size());
365        assert(pos[i] != -1);
366
367        return pos[i];
368    }
369
370    const llama_kv_cell_ext & ext_get(uint32_t i) const {
371        assert(i < pos.size());
372        assert(pos[i] != -1);
373
374        return ext[i];
375    }
376
377    // note: call only if the cell is not empty
378    llama_pos get_shift(uint32_t i) const {
379        assert(i < pos.size());
380        assert(pos[i] != -1);
381
382        return shift[i];
383    }
384
385    // check if a cell is not empty and its position is within [p0, p1)
386    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
387        assert(i < pos.size());
388
389        return pos[i] >= p0 && pos[i] < p1;
390    }
391
392    // set the position of an empty cell
393    // does not modify "has_shift"
394    // note: call only if the cell is empty
395    void pos_set(uint32_t i, llama_pos p) {
396        assert(i < pos.size());
397        assert(pos[i] == -1);
398        assert(seq[i].none());
399
400        pos[i] = p;
401
402        used.insert(i);
403    }
404
405    void ext_set(uint32_t i, llama_kv_cell_ext p) {
406        assert(i < ext.size());
407        ext[i] = p;
408    }
409
410    // pos[i] = pos[i] + d
411    // sets "has_shift" to true
412    // note: call only if the cell is not empty
413    bool pos_add(uint32_t i, llama_pos d) {
414        assert(i < pos.size());
415        assert(pos[i] != -1);
416
417        seq_pos_rm(i);
418
419        pos[i]   += d;
420        shift[i] += d;
421
422        has_shift = true;
423
424        if (pos[i] < 0) {
425            seq[i].reset();
426            pos[i] = -1;
427            shift[i] = 0;
428
429            used.erase(i);
430
431            return true;
432        }
433
434        seq_pos_add(i);
435
436        return false;
437    }
438
439    // pos[i] = pos[i] / d
440    // sets "has_shift" to true
441    // note: call only if the cell is not empty
442    void pos_div(uint32_t i, int d) {
443        assert(i < pos.size());
444        assert(pos[i] != -1);
445
446        const llama_pos p_old = pos[i];
447
448        seq_pos_rm(i);
449
450        pos[i]   /= d;
451        shift[i] += p_old - pos[i];
452
453        seq_pos_add(i);
454
455        has_shift = true;
456    }
457
458private:
459    bool has_shift = false;
460
461    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
462    std::set<uint32_t> used;
463
464    std::vector<llama_pos> pos;
465
466    // stores extra info per cell
467    std::vector<llama_kv_cell_ext> ext;
468
469    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
470    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
471    //
472    //   cells.pos_add(x, shift_x);
473    //   cells.pos_div(y, shift_y);
474    //   ...
475    //
476    //   if (cells.has_shift()) {
477    //      for (int i = 0; i < n; ++i) {
478    //          auto shift_i = cells.get_shift(i);
479    //          ...
480    //      }
481    //      cells.reset_shift();
482    //   }
483    //
484    std::vector<llama_pos> shift;
485
486    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
487
488    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
489    std::vector<seq_set_t> seq;
490
491    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
492    // if the position p is not present, seq_pos[s][p] is not set
493    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
494    //
495    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
496    //  - during performing a cache reuse via (rm + add)
497    //  - some vision models have input embeddings with repeating positions
498    //
499    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
500
501    // helper functions for updating `seq_pos`, once cell at a time:
502
503    void seq_pos_dec(llama_seq_id s, llama_pos p) {
504        auto it = seq_pos[s].find(p);
505        assert(it != seq_pos[s].end());
506
507        if (--it->second == 0) {
508            seq_pos[s].erase(it);
509        }
510    }
511
512    void seq_pos_inc(llama_seq_id s, llama_pos p) {
513        seq_pos[s][p]++;
514    }
515
516    // remove cell i
517    void seq_pos_rm(uint32_t i) {
518        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
519            if (seq[i].test(s)) {
520                seq_pos_dec(s, pos[i]);
521            }
522        }
523    }
524
525    // add cell i
526    void seq_pos_add(uint32_t i) {
527        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
528            if (seq[i].test(s)) {
529                seq_pos_inc(s, pos[i]);
530            }
531        }
532    }
533};