1#pragma once
  2
  3#include "llama-batch.h"
  4#include "llama-graph.h"
  5#include "llama-kv-cache.h"
  6#include "llama-memory.h"
  7#include "llama-memory-recurrent.h"
  8
  9#include <memory>
 10#include <vector>
 11
 12//
 13// llama_memory_hybrid
 14//
 15
 16// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 17//   support models where each layer may be either attention-based or recurrent
 18
 19class llama_memory_hybrid : public llama_memory_i {
 20public:
 21    llama_memory_hybrid(
 22        const llama_model & model,
 23                            /* attn */
 24                ggml_type   type_k,
 25                ggml_type   type_v,
 26                     bool   v_trans,
 27                 uint32_t   kv_size,
 28                 uint32_t   n_pad,
 29                 uint32_t   n_swa,
 30           llama_swa_type   swa_type,
 31                            /* recurrent */
 32                ggml_type   type_r,
 33                ggml_type   type_s,
 34                 uint32_t   rs_size,
 35                            /* common */
 36                 uint32_t   n_seq_max,
 37                     bool   offload,
 38                     bool   unified,
 39                            /* layer filters */
 40    const layer_filter_cb & filter_attn = nullptr,
 41    const layer_filter_cb & filter_recr = nullptr);
 42
 43    ~llama_memory_hybrid() = default;
 44
 45    //
 46    // llama_memory_i
 47    //
 48
 49    llama_memory_context_ptr init_batch(
 50            llama_batch_allocr & balloc,
 51            uint32_t n_ubatch,
 52            bool embd_all) override;
 53
 54    llama_memory_context_ptr init_full() override;
 55
 56    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 57
 58    bool get_can_shift() const override;
 59
 60    void clear(bool data) override;
 61
 62    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
 63    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
 64    void seq_keep(llama_seq_id seq_id)                                                          override;
 65    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
 66    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 67
 68    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
 69    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 70
 71    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
 72
 73    // state write/load
 74
 75    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
 76    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;
 77
 78    //
 79    // llama_memory_hybrid specific API
 80    //
 81
 82    llama_kv_cache * get_mem_attn() const;
 83    llama_memory_recurrent * get_mem_recr() const;
 84
 85private:
 86    const llama_hparams & hparams;
 87
 88    const std::unique_ptr<llama_kv_cache> mem_attn;
 89    const std::unique_ptr<llama_memory_recurrent> mem_recr;
 90};
 91
 92class llama_memory_hybrid_context : public llama_memory_context_i {
 93public:
 94    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 95
 96    // init failure
 97    explicit llama_memory_hybrid_context(llama_memory_status status);
 98
 99    // init full
100    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
101
102    // init update
103    explicit llama_memory_hybrid_context(
104        llama_memory_hybrid * mem,
105              llama_context * lctx,
106                       bool   optimize);
107
108    // init success
109    llama_memory_hybrid_context(
110              llama_memory_hybrid * mem,
111                  slot_info_vec_t   sinfos_attn,
112        std::vector<llama_ubatch>   ubatches);
113
114    ~llama_memory_hybrid_context() = default;
115
116    bool next()  override;
117    bool apply() override;
118
119    llama_memory_status  get_status() const override;
120    const llama_ubatch & get_ubatch() const override;
121
122    //
123    // llama_memory_hybrid_context
124    //
125
126    const llama_kv_cache_context * get_attn() const;
127    const llama_memory_recurrent_context * get_recr() const;
128
129private:
130    // the index of the next ubatch to process
131    size_t i_next = 0;
132
133    std::vector<llama_ubatch> ubatches;
134
135    const llama_memory_context_ptr ctx_attn;
136    const llama_memory_context_ptr ctx_recr;
137
138    const llama_memory_status status;
139};