summaryrefslogtreecommitdiff
path: root/llama.cpp/src/llama-memory.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/src/llama-memory.h')
-rw-r--r--llama.cpp/src/llama-memory.h122
1 files changed, 122 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-memory.h b/llama.cpp/src/llama-memory.h
new file mode 100644
index 0000000..4a157b9
--- /dev/null
+++ b/llama.cpp/src/llama-memory.h
@@ -0,0 +1,122 @@
+#pragma once
+
+#include "llama.h"
+
+#include <map>
+#include <memory>
+#include <functional>
+
+struct llama_ubatch;
+
+class llama_batch_allocr;
+
+class llama_io_write_i;
+class llama_io_read_i;
+
+struct llama_memory_params {
+ // kv cache
+ ggml_type type_k;
+ ggml_type type_v;
+
+ // use full-size SWA cache
+ bool swa_full;
+};
+
+enum llama_memory_status {
+ LLAMA_MEMORY_STATUS_SUCCESS = 0,
+ LLAMA_MEMORY_STATUS_NO_UPDATE,
+ LLAMA_MEMORY_STATUS_FAILED_PREPARE,
+ LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
+};
+
+// helper function for combining the status of two memory contexts
+// useful for implementing hybrid memory types (e.g. iSWA)
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
+// helper function for checking if a memory status indicates a failure
+bool llama_memory_status_is_fail(llama_memory_status status);
+
+// the interface for managing the memory context during batch processing
+// this interface is implemented per memory type. see:
+// - llama_kv_cache_context
+// - llama_kv_cache_iswa_context
+// ...
+//
+// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
+struct llama_memory_context_i {
+ virtual ~llama_memory_context_i() = default;
+
+ // consume the current ubatch from the context and proceed to the next one
+ // return false if we are done
+ virtual bool next() = 0;
+
+ // apply the memory state for the current ubatch to the memory object
+ // return false on failure
+ virtual bool apply() = 0;
+
+ // get the current ubatch
+ virtual const llama_ubatch & get_ubatch() const = 0;
+
+ // get the status of the memory context - used for error handling and checking if any updates would be applied
+ virtual llama_memory_status get_status() const = 0;
+};
+
+using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+struct llama_memory_i {
+ // this callback is used to filter out layers that should not be included in the cache
+ using layer_filter_cb = std::function<bool(int32_t il)>;
+
+ // this callback is used to specify which layers should reuse memory from other layers
+ // return negative value to indicate that the layer il should not reuse memory
+ using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
+ virtual ~llama_memory_i() = default;
+
+ // split the input batch into a set of ubatches and verify that they can fit into the cache
+ // return a context object containing the ubatches and memory state required to process them
+ // check the llama_memory_context_i::get_status() for the result
+ virtual llama_memory_context_ptr init_batch(
+ llama_batch_allocr & balloc,
+ uint32_t n_ubatch,
+ bool embd_all) = 0;
+
+ // simulate full cache, used for allocating worst-case compute buffers
+ virtual llama_memory_context_ptr init_full() = 0;
+
+ // prepare for any pending memory updates, such as shifts, copies, etc.
+ // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+ virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
+
+ // getters
+ virtual bool get_can_shift() const = 0;
+
+ //
+ // ops
+ //
+
+ // if data == true, the data buffers will also be cleared together with the metadata
+ virtual void clear(bool data) = 0;
+
+ virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
+ virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+ virtual void seq_keep(llama_seq_id seq_id) = 0;
+ virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
+ virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
+
+ virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
+ virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+ virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
+
+ //
+ // state write/read
+ //
+
+ virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
+ virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
+};
+
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;