summaryrefslogtreecommitdiff
path: root/llama.cpp/src/llama-memory-recurrent.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/src/llama-memory-recurrent.h')
-rw-r--r--llama.cpp/src/llama-memory-recurrent.h182
1 files changed, 182 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-memory-recurrent.h b/llama.cpp/src/llama-memory-recurrent.h
new file mode 100644
index 0000000..47f01d7
--- /dev/null
+++ b/llama.cpp/src/llama-memory-recurrent.h
@@ -0,0 +1,182 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-memory.h"
+
+#include <map>
+#include <set>
+#include <vector>
+
+//
+// llama_memory_recurrent
+//
+
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
+// see the implementation of llama_kv_cache_context_i for an example how to do it
+class llama_memory_recurrent : public llama_memory_i {
+public:
+ llama_memory_recurrent(
+ const llama_model & model,
+ ggml_type type_r,
+ ggml_type type_s,
+ bool offload,
+ uint32_t mem_size,
+ uint32_t n_seq_max,
+ const layer_filter_cb & filter);
+
+ ~llama_memory_recurrent() = default;
+
+ //
+ // llama_memory_i
+ //
+
+ llama_memory_context_ptr init_batch(
+ llama_batch_allocr & balloc,
+ uint32_t n_ubatch,
+ bool embd_all) override;
+
+ llama_memory_context_ptr init_full() override;
+
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+ void clear(bool data) override;
+
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+ void seq_keep(llama_seq_id seq_id) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+ std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
+ bool prepare(const std::vector<llama_ubatch> & ubatches);
+
+ // find a contiguous slot of memory cells and emplace the ubatch there
+ bool find_slot(const llama_ubatch & ubatch);
+
+ bool get_can_shift() const override;
+
+ // state write/load
+
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+ uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+ uint32_t size = 0; // total number of cells, shared across all sequences
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+ // computed before each graph build
+ uint32_t n = 0;
+
+ // first zero-ed state
+ int32_t rs_z = -1;
+
+ // TODO: optimize for recurrent state needs
+ struct mem_cell {
+ llama_pos pos = -1;
+ int32_t src = -1; // used to know where states should be copied from
+ int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
+ int32_t tail = -1;
+
+ std::set<llama_seq_id> seq_id;
+
+ bool has_seq_id(const llama_seq_id & id) const {
+ return seq_id.find(id) != seq_id.end();
+ }
+
+ bool is_empty() const {
+ return seq_id.empty();
+ }
+
+ bool is_same_seq(const mem_cell & other) const {
+ return seq_id == other.seq_id;
+ }
+ };
+
+ std::vector<mem_cell> cells;
+
+ // per layer
+ std::vector<ggml_tensor *> r_l;
+ std::vector<ggml_tensor *> s_l;
+
+private:
+ //const llama_model & model;
+ const llama_hparams & hparams;
+
+ const uint32_t n_seq_max = 1;
+
+ // ggml contexts for the KV cache along with the allocated backend buffers:
+ std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
+
+ size_t total_size() const;
+
+ size_t size_r_bytes() const;
+ size_t size_s_bytes() const;
+
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
+class llama_memory_recurrent_context : public llama_memory_context_i {
+public:
+ // used for errors
+ llama_memory_recurrent_context(llama_memory_status status);
+
+ // used to create a full-cache or update context
+ llama_memory_recurrent_context(
+ llama_memory_recurrent * mem);
+
+ // used to create a batch processing context from a batch
+ llama_memory_recurrent_context(
+ llama_memory_recurrent * mem,
+ std::vector<llama_ubatch> ubatches);
+
+ virtual ~llama_memory_recurrent_context();
+
+ //
+ // llama_memory_context_i
+ //
+
+ bool next() override;
+ bool apply() override;
+
+ llama_memory_status get_status() const override;
+ const llama_ubatch & get_ubatch() const override;
+
+ //
+ // llama_memory_recurrent_context specific API
+ //
+
+ uint32_t get_n_rs() const;
+ uint32_t get_head() const;
+ int32_t get_rs_z() const;
+ uint32_t get_size() const;
+
+ ggml_tensor * get_r_l(int32_t il) const;
+ ggml_tensor * get_s_l(int32_t il) const;
+
+ int32_t s_copy(int i) const;
+
+private:
+ const llama_memory_status status;
+
+ llama_memory_recurrent * mem;
+
+ size_t i_next = 0;
+
+ std::vector<llama_ubatch> ubatches;
+
+ //
+ // data needed for building the compute graph for the current ubatch:
+ // TODO: extract all the state like `head` and `n` here
+ //
+
+ const bool is_full = false;
+};