summaryrefslogtreecommitdiff
path: root/llama.cpp/src/llama-memory-hybrid-iswa.h
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/src/llama-memory-hybrid-iswa.h
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/src/llama-memory-hybrid-iswa.h')
-rw-r--r--llama.cpp/src/llama-memory-hybrid-iswa.h140
1 file changed, 140 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-memory-hybrid-iswa.h b/llama.cpp/src/llama-memory-hybrid-iswa.h
new file mode 100644
index 0000000..807c8aa
--- /dev/null
+++ b/llama.cpp/src/llama-memory-hybrid-iswa.h
@@ -0,0 +1,140 @@
1#pragma once
2
3#include "llama-batch.h"
4#include "llama-graph.h"
5#include "llama-kv-cache-iswa.h"
6#include "llama-memory.h"
7#include "llama-memory-recurrent.h"
8
9#include <memory>
10#include <vector>
11
12//
13// llama_memory_hybrid_iswa
14//
15
16// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
17// support models where each layer may be either attention-based (with SWA support) or recurrent
18
class llama_memory_hybrid_iswa : public llama_memory_i {
public:
    // Constructs both sub-memories from one parameter list:
    //  - the /* attn */ group is forwarded to the llama_kv_cache_iswa member,
    //  - the /* recurrent */ group to the llama_memory_recurrent member,
    //  - the /* common */ group applies to both.
    // filter_attn / filter_recr select which model layers each sub-memory
    // manages; nullptr presumably falls back to a default per-layer split —
    // NOTE(review): confirm the default behavior in the .cpp implementation.
    llama_memory_hybrid_iswa(
        const llama_model & model,
        /* attn */
        ggml_type type_k,
        ggml_type type_v,
        bool v_trans,
        bool swa_full,
        uint32_t kv_size,
        uint32_t n_ubatch,
        uint32_t n_pad,
        /* recurrent */
        ggml_type type_r,
        ggml_type type_s,
        uint32_t rs_size,
        /* common */
        uint32_t n_seq_max,
        bool offload,
        bool unified,
        /* layer filters */
        const layer_filter_cb & filter_attn = nullptr,
        const layer_filter_cb & filter_recr = nullptr);

    // Rule of Zero: the unique_ptr members release both sub-memories.
    ~llama_memory_hybrid_iswa() = default;

    //
    // llama_memory_i
    //

    // Prepare memory for processing a batch; splits it into ubatches of at
    // most n_ubatch tokens (exact splitting semantics live in the .cpp).
    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    // Context covering the full memory state (e.g. for a full defrag/rebuild
    // pass — TODO confirm intended use against llama_memory_i docs).
    llama_memory_context_ptr init_full() override;

    // Context for applying pending memory updates (shifts, defrag, ...).
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    // Sequence manipulation — forwarded to both sub-memories
    // (presumably; bodies are not visible here).
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift)             override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d)                       override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // Per-backend-buffer-type memory usage, for reporting/diagnostics.
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load
    // seq_id == -1 presumably means "all sequences" — verify against
    // the llama_memory_i contract.

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;

    //
    // llama_memory_hybrid_iswa specific API
    //

    // Non-owning accessors to the two sub-memories (lifetime is tied to
    // this object via the const unique_ptr members below).
    llama_kv_cache_iswa    * get_mem_attn() const;
    llama_memory_recurrent * get_mem_recr() const;

private:
    const llama_hparams & hparams;

    // Owned sub-memories: SWA-capable attention KV cache + recurrent state.
    // const pointers: set once at construction, never reseated.
    const std::unique_ptr<llama_kv_cache_iswa>    mem_attn;
    const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
91
// Memory context produced by llama_memory_hybrid_iswa::init_* — wraps one
// sub-context per sub-memory and iterates over a list of ubatches.
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;

    // init failure
    explicit llama_memory_hybrid_iswa_context(llama_memory_status status);

    // init full
    explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);

    // init update
    explicit llama_memory_hybrid_iswa_context(
            llama_memory_hybrid_iswa * mem,
            llama_context * lctx,
            bool optimize);

    // init success — carries the slot placements for the base and SWA caches
    // plus the ubatches to process.
    llama_memory_hybrid_iswa_context(
            llama_memory_hybrid_iswa * mem,
            slot_info_vec_t sinfos_base,
            slot_info_vec_t sinfos_swa,
            std::vector<llama_ubatch> ubatches);

    ~llama_memory_hybrid_iswa_context() = default;

    // Advance to the next ubatch; presumably returns false when exhausted —
    // confirm against the llama_memory_context_i contract.
    bool next()  override;
    // Apply the current step to the underlying memory.
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_hybrid_iswa_context
    //

    // Typed views of the two sub-contexts (non-owning).
    const llama_kv_cache_iswa_context      * get_attn() const;
    const llama_memory_recurrent_context   * get_recr() const;

private:
    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    // Per-sub-memory contexts; const pointers, set once at construction.
    const llama_memory_context_ptr ctx_attn;
    const llama_memory_context_ptr ctx_recr;

    const llama_memory_status status;
};