diff options
Diffstat (limited to 'llama.cpp/src/llama-memory-recurrent.h')
| -rw-r--r-- | llama.cpp/src/llama-memory-recurrent.h | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-memory-recurrent.h b/llama.cpp/src/llama-memory-recurrent.h new file mode 100644 index 0000000..47f01d7 --- /dev/null +++ b/llama.cpp/src/llama-memory-recurrent.h @@ -0,0 +1,182 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-memory.h" + +#include <map> +#include <set> +#include <vector> + +// +// llama_memory_recurrent +// + +// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i +// see the implementation of llama_kv_cache_context_i for an example how to do it +class llama_memory_recurrent : public llama_memory_i { +public: + llama_memory_recurrent( + const llama_model & model, + ggml_type type_r, + ggml_type type_s, + bool offload, + uint32_t mem_size, + uint32_t n_seq_max, + const layer_filter_cb & filter); + + ~llama_memory_recurrent() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + bool prepare(const std::vector<llama_ubatch> & ubatches); + + // find a contiguous slot of memory cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + // first zero-ed state + int32_t rs_z = -1; + + // TODO: optimize for recurrent state needs + struct mem_cell { + llama_pos pos = -1; + int32_t src = -1; // used to know where states should be copied from + int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once) + int32_t tail = -1; + + std::set<llama_seq_id> seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const mem_cell & other) const { + return seq_id == other.seq_id; + } + }; + + std::vector<mem_cell> cells; + + // per layer + std::vector<ggml_tensor *> r_l; + std::vector<ggml_tensor *> s_l; + +private: + //const llama_model & model; + const llama_hparams & hparams; + + const uint32_t n_seq_max = 1; + + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs; + + size_t total_size() const; + + size_t size_r_bytes() const; + size_t size_s_bytes() const; + + void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_memory_recurrent_context : public llama_memory_context_i { +public: + // used for errors + llama_memory_recurrent_context(llama_memory_status status); + + // used to create a full-cache or update context + llama_memory_recurrent_context( + llama_memory_recurrent * mem); + + // used to create a batch processing context from a batch + llama_memory_recurrent_context( + llama_memory_recurrent * mem, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_memory_recurrent_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_recurrent_context specific API + // + + uint32_t get_n_rs() const; + uint32_t get_head() const; + int32_t get_rs_z() const; + uint32_t get_size() const; + + ggml_tensor * get_r_l(int32_t il) const; + ggml_tensor * get_s_l(int32_t il) const; + + int32_t s_copy(int i) const; + +private: + const llama_memory_status status; + + llama_memory_recurrent * mem; + + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; |
