1#pragma once
  2
  3#include "llama.h"
  4
  5#include <map>
  6#include <memory>
  7#include <functional>
  8
  9struct llama_ubatch;
 10
 11class llama_batch_allocr;
 12
 13class llama_io_write_i;
 14class llama_io_read_i;
 15
 16struct llama_memory_params {
 17    // kv cache
 18    ggml_type type_k;
 19    ggml_type type_v;
 20
 21    // use full-size SWA cache
 22    bool swa_full;
 23};
 24
 25enum llama_memory_status {
 26    LLAMA_MEMORY_STATUS_SUCCESS = 0,
 27    LLAMA_MEMORY_STATUS_NO_UPDATE,
 28    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
 29    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 30};
 31
 32// helper function for combining the status of two memory contexts
 33// useful for implementing hybrid memory types (e.g. iSWA)
 34llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
 35
 36// helper function for checking if a memory status indicates a failure
 37bool llama_memory_status_is_fail(llama_memory_status status);
 38
 39// the interface for managing the memory context during batch processing
 40// this interface is implemented per memory type. see:
 41//   - llama_kv_cache_context
 42//   - llama_kv_cache_iswa_context
 43//   ...
 44//
 45// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
 46struct llama_memory_context_i {
 47    virtual ~llama_memory_context_i() = default;
 48
 49    // consume the current ubatch from the context and proceed to the next one
 50    // return false if we are done
 51    virtual bool next() = 0;
 52
 53    // apply the memory state for the current ubatch to the memory object
 54    // return false on failure
 55    virtual bool apply() = 0;
 56
 57    // get the current ubatch
 58    virtual const llama_ubatch & get_ubatch() const = 0;
 59
 60    // get the status of the memory context - used for error handling and checking if any updates would be applied
 61    virtual llama_memory_status get_status() const = 0;
 62};
 63
 64using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 65
 66// general concept of LLM memory
 67// the KV cache is a type of LLM memory, but there can be other types
 68struct llama_memory_i {
 69    // this callback is used to filter out layers that should not be included in the cache
 70    using layer_filter_cb = std::function<bool(int32_t il)>;
 71
 72    // this callback is used to specify which layers should reuse memory from other layers
 73    // return negative value to indicate that the layer il should not reuse memory
 74    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
 75
 76    virtual ~llama_memory_i() = default;
 77
 78    // split the input batch into a set of ubatches and verify that they can fit into the cache
 79    // return a context object containing the ubatches and memory state required to process them
 80    // check the llama_memory_context_i::get_status() for the result
 81    virtual llama_memory_context_ptr init_batch(
 82            llama_batch_allocr & balloc,
 83            uint32_t n_ubatch,
 84            bool embd_all) = 0;
 85
 86    // simulate full cache, used for allocating worst-case compute buffers
 87    virtual llama_memory_context_ptr init_full() = 0;
 88
 89    // prepare for any pending memory updates, such as shifts, copies, etc.
 90    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
 91    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 92
 93    // getters
 94    virtual bool get_can_shift() const = 0;
 95
 96    //
 97    // ops
 98    //
 99
100    // if data == true, the data buffers will also be cleared together with the metadata
101    virtual void clear(bool data) = 0;
102
103    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
104    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
105    virtual void seq_keep(llama_seq_id seq_id) = 0;
106    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
107    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
108
109    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
110    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
111
112    virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
113
114    //
115    // state write/read
116    //
117
118    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
119    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
120};
121
122using llama_memory_ptr = std::unique_ptr<llama_memory_i>;