1#pragma once
  2
  3#include "llama-batch.h"
  4#include "llama-graph.h"
  5#include "llama-memory.h"
  6
  7#include <map>
  8#include <set>
  9#include <vector>
 10
 11//
 12// llama_memory_recurrent
 13//
 14
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context
//       see the implementation of llama_kv_cache_context_i for an example of how to do it
 17class llama_memory_recurrent : public llama_memory_i {
 18public:
 19    llama_memory_recurrent(
 20            const llama_model & model,
 21                    ggml_type   type_r,
 22                    ggml_type   type_s,
 23                         bool   offload,
 24                     uint32_t   mem_size,
 25                     uint32_t   n_seq_max,
 26        const layer_filter_cb & filter);
 27
 28    ~llama_memory_recurrent() = default;
 29
 30    //
 31    // llama_memory_i
 32    //
 33
 34    llama_memory_context_ptr init_batch(
 35            llama_batch_allocr & balloc,
 36            uint32_t n_ubatch,
 37            bool embd_all) override;
 38
 39    llama_memory_context_ptr init_full() override;
 40
 41    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 42
 43    void clear(bool data) override;
 44
 45    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
 46    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
 47    void seq_keep(llama_seq_id seq_id)                                                          override;
 48    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
 49    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 50
 51    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
 52    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 53
 54    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
 55
 56    bool prepare(const std::vector<llama_ubatch> & ubatches);
 57
 58    // find a contiguous slot of memory cells and emplace the ubatch there
 59    bool find_slot(const llama_ubatch & ubatch);
 60
 61    bool get_can_shift() const override;
 62
 63    // state write/load
 64
 65    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
 66    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 67
 68    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
 69    uint32_t size = 0; // total number of cells, shared across all sequences
 70    uint32_t used = 0; // used cells (i.e. at least one seq_id)
 71
 72    // computed before each graph build
 73    uint32_t n = 0;
 74
 75    // first zero-ed state
 76    int32_t rs_z = -1;
 77
 78    // TODO: optimize for recurrent state needs
 79    struct mem_cell {
 80        llama_pos pos  = -1;
 81        int32_t   src  = -1; // used to know where states should be copied from
 82        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
 83        int32_t   tail = -1;
 84
 85        std::set<llama_seq_id> seq_id;
 86
 87        bool has_seq_id(const llama_seq_id & id) const {
 88            return seq_id.find(id) != seq_id.end();
 89        }
 90
 91        bool is_empty() const {
 92            return seq_id.empty();
 93        }
 94
 95        bool is_same_seq(const mem_cell & other) const {
 96            return seq_id == other.seq_id;
 97        }
 98    };
 99
100    std::vector<mem_cell> cells;
101
102    // per layer
103    std::vector<ggml_tensor *> r_l;
104    std::vector<ggml_tensor *> s_l;
105
106private:
107    //const llama_model & model;
108    const llama_hparams & hparams;
109
110    const uint32_t n_seq_max = 1;
111
112    // ggml contexts for the KV cache along with the allocated backend buffers:
113    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
114
115    size_t total_size() const;
116
117    size_t size_r_bytes() const;
118    size_t size_s_bytes() const;
119
120    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
121    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
122
123    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
124    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
125};
126
127class llama_memory_recurrent_context : public llama_memory_context_i {
128public:
129    // used for errors
130    llama_memory_recurrent_context(llama_memory_status status);
131
132    // used to create a full-cache or update context
133    llama_memory_recurrent_context(
134            llama_memory_recurrent * mem);
135
136    // used to create a batch processing context from a batch
137    llama_memory_recurrent_context(
138            llama_memory_recurrent * mem,
139            std::vector<llama_ubatch> ubatches);
140
141    virtual ~llama_memory_recurrent_context();
142
143    //
144    // llama_memory_context_i
145    //
146
147    bool next()  override;
148    bool apply() override;
149
150    llama_memory_status  get_status() const override;
151    const llama_ubatch & get_ubatch() const override;
152
153    //
154    // llama_memory_recurrent_context specific API
155    //
156
157    uint32_t get_n_rs() const;
158    uint32_t get_head() const;
159    int32_t  get_rs_z() const;
160    uint32_t get_size() const;
161
162    ggml_tensor * get_r_l(int32_t il) const;
163    ggml_tensor * get_s_l(int32_t il) const;
164
165    int32_t s_copy(int i) const;
166
167private:
168    const llama_memory_status status;
169
170    llama_memory_recurrent * mem;
171
172    size_t i_next = 0;
173
174    std::vector<llama_ubatch> ubatches;
175
176    //
177    // data needed for building the compute graph for the current ubatch:
178    // TODO: extract all the state like `head` and `n` here
179    //
180
181    const bool is_full = false;
182};