1#include "models.h"
2#include "ggml.h"
3
// Number of tokens processed per chunk by build_kda_chunking(); also the size of the
// CHUNK_SIZE x CHUNK_SIZE causal/identity/diagonal masks built in the constructor.
#define CHUNK_SIZE 64
5
6// Causal Conv1d function for Q,K,V
7// When qkv is 0, it is Q, 1 is K, 2 is V
8static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
9 const int64_t d_inner = head_dim * n_head;
10 const int64_t conv_state_size = (d_conv - 1) * d_inner;
11 const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V
12
13 // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
14 // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
15 // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
16 // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
17 // View Q conv state: offset 0, size conv_state_size per seq
18 // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
19 // state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
20 // We want [d_conv-1, d_inner, n_seqs] view:
21 // nb1 = (d_conv-1) * element_size (stride between channels)
22 // nb2 = n_embd_r_total * element_size (stride between seqs)
23 ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
24 (d_conv - 1) * ggml_element_size(conv_state_all), // nb1: stride between channels
25 n_embd_r_total * ggml_element_size(conv_state_all), // nb2: stride between seqs
26 qkv * conv_state_size * ggml_element_size(conv_state_all));
27
28// Causal Conv1d function for Q,K,V
29// When qkv is 0, it is Q, 1 is K, 2 is V
30 // Step 1: Q, K, V projections -> [d_inner, n_tokens]
31 ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
32
33 // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
34 ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
35
36 // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
37 ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
38
39 // Save last (d_conv-1) columns back to Q conv state
40 ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
41 conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
42 ggml_build_forward_expand(gf,
43 ggml_cpy(ctx0, last_conv_x,
44 ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
45 (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));
46 // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
47 // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
48 // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
49 // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
50 // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
51 // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
52 ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
53
54 // Apply conv1d
55 // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
56 ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
57 // Reshape to 2D for bias add: {d_inner, n_tokens}
58 Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
59 Xcur = ggml_silu(ctx0, Xcur);
60
61 return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
62}
63
64llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
65 llm_graph_context_mamba(params), model(model) {
66 ggml_tensor * cur;
67 ggml_tensor * inpL;
68
69 inpL = build_inp_embd(model.tok_embd);
70 cb(inpL, "model.embed_tokens", -1);
71
72 // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
73 // So we don't need inp_pos
74
75 auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
76 auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
77 auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
78 auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
79 auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
80
81 // Output ids for selecting which tokens to output
82 ggml_tensor * inp_out_ids = build_inp_out_ids();
83
84 ggml_tensor * chunked_causal_mask =
85 ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
86 GGML_TRI_TYPE_LOWER);
87
88 ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
89 ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
90
91 ggml_build_forward_expand(gf, chunked_causal_mask);
92 ggml_build_forward_expand(gf, chunked_identity);
93 ggml_build_forward_expand(gf, chunked_diag_mask);
94
95 // Kimi dimension constants
96 const int64_t n_head = hparams.n_head();
97 const int64_t head_dim = hparams.n_embd_head_kda;
98 const int64_t d_conv = hparams.ssm_d_conv;
99 const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096
100 const int64_t n_seqs = ubatch.n_seqs;
101 const int64_t n_seq_tokens = ubatch.n_seq_tokens;
102
103 // Verify batch consistency for recurrent layers
104 GGML_ASSERT(n_seqs != 0);
105 GGML_ASSERT(ubatch.equal_seqs());
106 GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
107
108 // MLA params
109 const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
110 const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
111 const int64_t kv_lora_rank = hparams.n_lora_kv;
112 // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
113 // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
114 const int64_t n_embd_head_qk_rope = hparams.n_rot; // config.qk_rope_head_dim
115 const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128
116 // Attention scale for MLA
117 const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
118
119 for (int il = 0; il < n_layer; ++il) {
120 const auto & layer = model.layers[il];
121 ggml_tensor * inpSA = inpL;
122
123 // Attention Norm
124 cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
125 cb(cur, "attn_norm", il);
126
127 // Check layer type by checking which tensors exist
128 // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
129 bool is_kda = (layer.ssm_a != nullptr);
130 bool is_mla = (layer.wkv_a_mqa != nullptr);
131
132 if (is_kda) {
133 // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
134 // Reference: vLLM kda.py
135 const auto * mctx_cur = inp_rs->mctx;
136 const auto kv_head = mctx_cur->get_head();
137
138 // Get conv states from r_l tensor (Q, K, V each have separate state)
139 ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
140 cb(conv_states_all, "conv_states_all", il);
141 ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
142 ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
143 ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
144 ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
145
146 // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
147 ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
148 ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
149 cb(g1, "g1 f_b(f_a(cur))", il);
150 g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
151 g1 = ggml_softplus(ctx0, g1);
152 g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
153
154 // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
155 // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
156 ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
157 g1 = ggml_mul(ctx0, g1, A);
158 cb(g1, "kda_g1", il);
159
160 // Compute beta (mixing coefficient)
161 ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
162 beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
163 cb(beta, "kda_beta", il);
164
165 // Reshape for KDA recurrence
166 // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
167 cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
168
169 g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
170
171 // Get SSM state and compute KDA recurrence using ggml_kda_scan
172 ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
173 ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
174 state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
175 // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
176 std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
177 build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
178 build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
179
180 ggml_tensor * output = attn_out.first;
181 ggml_tensor * new_state = attn_out.second;
182 cb(output, "attn_output", il);
183 cb(new_state, "new_state", il);
184
185 // Update the recurrent states
186 ggml_build_forward_expand(gf,
187 ggml_cpy(ctx0, new_state,
188 ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
189 kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
190
191 // Output gating g2 = g_b(g_a(x))
192 ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
193 ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
194 ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
195 cb(g2, "g2 g_b(g_a(cur_2d))", il);
196 g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
197
198 // Apply o_norm with sigmoid gating
199 // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
200 // Formula: output = RMSNorm(x) * sigmoid(g)
201 ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs);
202 ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
203 cb(normed, "kda_normed", il);
204 ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
205 ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
206
207 // Output projection
208 gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
209 cur = ggml_mul_mat(ctx0, layer.wo, gated);
210 cb(cur, "kda_out", il);
211
212 } else if (is_mla) {
213 // === MLA Layer (Multi-head Latent Attention) without KV Cache ===
214 // Reference: vLLM mla.py
215 // Step 1: Q projection and reshape
216 // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
217 // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
218 ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
219
220 // Step 2: KV compression
221 // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
222 ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
223
224 // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
225 ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
226 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
227 ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
228 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
229 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
230 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
231 // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
232 // k_pe is used directly without RoPE
233 // Normalize kv_c
234 kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
235
236 if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
237 // extract q_nope
238 ggml_tensor * q_nope =
239 ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
240 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
241 cb(q_nope, "q_nope", il);
242
243 // and {n_embd_head_qk_rope, n_head, n_tokens}
244 ggml_tensor * q_pe = ggml_view_3d(
245 ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
246 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
247 cb(q_pe, "q_pe", il);
248
249 // {n_embd_head_qk_nope, n_tokens, n_head}
250 q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
251 cb(q_nope, "q_nope_perm", il);
252
253 // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
254 ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
255 cb(q_nope_absorbed, "q_nope_absorbed", il);
256
257 // {kv_lora_rank, n_head, n_tokens}
258 q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
259 cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
260
261 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
262 // note: rope must go first for in-place context shifting in build_rope_shift()
263 Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
264 cb(Qcur, "Qcur", il);
265
266 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
267 cb(kv_cmpr, "kv_cmpr_reshape", il);
268
269 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
270 ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
271 cb(Kcur, "Kcur", il);
272
273 // {kv_lora_rank, 1, n_tokens}
274 ggml_tensor * Vcur = kv_cmpr;
275 cb(Vcur, "Vcur", il);
276
277 cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
278 cb(cur, "mla_out", il);
279 } else { // MLA KV cache disabled. Fall back to MHA KV cache.
280 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
281 cb(Qcur, "mla_Q", il);
282 // KV decompression: kv = kv_b_proj(kv_c_normed)
283 ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
284 const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
285
286 // Split kv into k_nope and v
287 ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
288 ggml_row_size(kv->type, kv_per_head),
289 ggml_row_size(kv->type, kv_per_head * n_head), 0);
290 ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
291 ggml_row_size(kv->type, kv_per_head),
292 ggml_row_size(kv->type, kv_per_head * n_head),
293 ggml_row_size(kv->type, n_embd_head_qk_nope));
294 Vcur = ggml_cont(ctx0, Vcur);
295 cb(Vcur, "mla_V", il);
296
297 // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
298 // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
299 // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
300 // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
301 ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
302 ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
303 ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
304 cb(Kcur, "mla_K", il);
305
306 // Direct softmax attention (with MHA KV cache)
307 // Use build_attn with inp_attn for proper mask handling
308 cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
309 cb(cur, "mla_out", il);
310 }
311 } else {
312 // Unknown layer type - this should not happen
313 GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
314 }
315
316 // On last layer, select only the output tokens
317 if (il == n_layer - 1 && inp_out_ids) {
318 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
319 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
320 }
321
322 // Residual
323 ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
324 cb(ffn_inp, "ffn_inp", il);
325
326 // FFN Norm
327 cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
328 cb(cur, "ffn_norm", il);
329
330 if ((uint32_t) il < hparams.n_layer_dense_lead) {
331 // Dense FFN layer
332 cur = build_ffn(cur,
333 layer.ffn_up, NULL, NULL,
334 layer.ffn_gate, NULL, NULL,
335 layer.ffn_down, NULL, NULL,
336 NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
337 cb(cur, "ffn_out", il);
338 } else {
339 // MoE layer
340 // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
341 ggml_tensor * moe_out = build_moe_ffn(cur,
342 layer.ffn_gate_inp,
343 layer.ffn_up_exps,
344 layer.ffn_gate_exps,
345 layer.ffn_down_exps,
346 layer.ffn_exp_probs_b,
347 hparams.n_expert,
348 hparams.n_expert_used,
349 LLM_FFN_SILU, true,
350 true, hparams.expert_weights_scale,
351 (llama_expert_gating_func_type) hparams.expert_gating_func,
352 il);
353 cb(moe_out, "ffn_moe_out", il);
354
355 // Shared expert
356 {
357 ggml_tensor * ffn_shexp = build_ffn(cur,
358 layer.ffn_up_shexp, NULL, NULL,
359 layer.ffn_gate_shexp, NULL, NULL,
360 layer.ffn_down_shexp, NULL, NULL,
361 NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
362 cb(ffn_shexp, "ffn_shexp", il);
363
364 cur = ggml_add(ctx0, moe_out, ffn_shexp);
365 cb(cur, "ffn_out", il);
366 }
367 }
368 // Residual
369 cur = ggml_add(ctx0, cur, ffn_inp);
370
371 cur = build_cvec(cur, il);
372 cb(cur, "l_out", il);
373
374 inpL = cur;
375 }
376 cur = inpL;
377
378 // Final Norm
379 cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
380
381 cb(cur, "result_norm", -1);
382 res->t_embd = cur;
383
384 // Output
385 cur = ggml_mul_mat(ctx0, model.output, cur);
386 cb(cur, "result_output", -1);
387 res->t_logits = cur;
388
389 ggml_build_forward_expand(gf, cur);
390}
391
392/*
393 This is a ggml implementation of the naive_chunk_kda function of
394 https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
395*/
396std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunking(
397 ggml_tensor * q,
398 ggml_tensor * k,
399 ggml_tensor * v,
400 ggml_tensor * gk,
401 ggml_tensor * beta,
402 ggml_tensor * state,
403 ggml_tensor * causal_mask,
404 ggml_tensor * identity,
405 ggml_tensor * diag_mask,
406 int il) {
407 GGML_ASSERT(ggml_is_contiguous(state));
408
409 const int64_t S_k = q->ne[0];
410 const int64_t H_k = q->ne[1];
411 const int64_t n_tokens = q->ne[2];
412 const int64_t n_seqs = q->ne[3];
413
414 const int64_t S_v = v->ne[0];
415 const int64_t H_v = v->ne[1];
416
417 GGML_ASSERT(v->ne[2] == n_tokens);
418 GGML_ASSERT(k->ne[2] == n_tokens);
419 GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
420 GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
421 GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
422
423 GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
424 GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
425
426 GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
427
428 // TODO: can this ever be false?
429 const bool use_qk_l2norm = true;
430
431 if (use_qk_l2norm) {
432 const float eps_norm = hparams.f_norm_rms_eps;
433
434 q = ggml_l2_norm(ctx0, q, eps_norm);
435 k = ggml_l2_norm(ctx0, k, eps_norm);
436 }
437
438 const float scale = 1.0f / sqrtf(S_v);
439
440 beta = ggml_sigmoid(ctx0, beta);
441
442 cb(q, "q_in", il);
443 cb(k, "k_in", il);
444 cb(v, "v_in", il);
445 cb(beta, "beta_in", il);
446 cb(gk, "gk_in", il);
447
448 q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
449 k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
450 v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
451 gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
452
453 beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
454 state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
455
456 cb(q, "q_perm", il);
457 cb(k, "k_perm", il);
458 cb(v, "v_perm", il);
459 cb(beta, "beta_perm", il);
460 cb(gk, "gk_perm", il);
461 cb(state, "state_in", il);
462
463 GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
464 GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
465 GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
466 GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
467
468 // Do padding
469 const int64_t chunk_size = CHUNK_SIZE;
470
471 const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
472 const int64_t n_chunks = (n_tokens + pad) / chunk_size;
473
474 q = ggml_pad(ctx0, q, 0, pad, 0, 0);
475 k = ggml_pad(ctx0, k, 0, pad, 0, 0);
476 v = ggml_pad(ctx0, v, 0, pad, 0, 0);
477 gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
478 beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
479
480 cb(q, "q_pad", il);
481 cb(k, "k_pad", il);
482 cb(v, "v_pad", il);
483 cb(beta, "beta_pad", il);
484 cb(gk, "gk_pad", il);
485
486 ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
487 ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
488
489 cb(v_beta, "v_beta", il);
490 cb(k_beta, "k_beta", il);
491
492 const int64_t HB = H_k * n_seqs;
493
494 q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB);
495 k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB);
496 k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
497 v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB);
498 v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
499
500 gk = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
501 beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
502
503 // switch for cumsum
504 gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
505 cb(gk, "gk", il);
506 ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
507 cb(gk_cumsum, "gk_cumsum", il);
508
509/*
510 Compute Akk and Aqk loop together
511 Akk loop:
512 for i in range(BT):
513 k_i = k[..., i, :] # k_i [B,H,NT,S]
514 g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S]
515 A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
516 Aqk loop:
517 for j in range(BT):
518 k_j = k[:, :, i, j]
519 g_j = g[:, :, i, j:j+1, :]
520 A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
521*/
522 const int64_t CHB = n_chunks * H_k * n_seqs;
523 ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
524 ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB]
525
526 ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
527 // decay_mask [chunk_size,chunk_size,S_k,CHB]
528 ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
529 cb(decay_mask, "decay_mask", il);
530
531 decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
532 cb(decay_mask, "decay_masked", il);
533 decay_mask = ggml_exp(ctx0, decay_mask);
534 decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
535
536 // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
537 decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
538
539 ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
540 ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
541 ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
542
543 ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
544 ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
545
546 // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
547 ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
548 ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
549 Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
550 Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
551 cb(Akk, "Akk", il);
552 cb(Aqk, "Aqk", il);
553
554 Akk = ggml_mul(ctx0, Akk, beta);
555 Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
556 cb(Akk, "attn_pre_solve", il);
557
558 Aqk = ggml_mul(ctx0, Aqk, diag_mask);
559 Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
560 cb(Aqk, "Aqk_masked", il);
561
562 // for i in range(1, chunk_size):
563 // row = attn[..., i, :i].clone()
564 // sub = attn[..., :i, :i].clone()
565 // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
566 // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
567 //
568 // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
569 ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
570 ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
571
572 ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
573 Akk = ggml_mul(ctx0, lin_solve, causal_mask);
574 Akk = ggml_add(ctx0, Akk, identity);
575
576 cb(Akk, "attn_solved", il);
577
578 // switch back for downstream
579 gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
580 ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum);
581 cb(gk_cumsum, "gk_cumsum", il);
582
583 // u = (A*beta[..., None, :]) @ v aka U_[t]
584 ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
585
586 ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
587 cb(kbeta_gkexp, "kbeta_gkexp", il);
588
589 ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
590 cb(k_cumdecay, "k_cumdecay", il);
591
592 ggml_tensor * core_attn_out = nullptr;
593 ggml_tensor * new_state = ggml_dup(ctx0, state);
594
595 cb(new_state, "new_state", il);
596
597 for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
598// extract one chunk worth of data
599 auto chunkify = [=](ggml_tensor * t) {
600 return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
601 t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
602 };
603 auto chunkify_A = [=](ggml_tensor * t) {
604 return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
605 t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
606 };
607
608
609// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
610 ggml_tensor * k_chunk = chunkify(k);
611 ggml_tensor * q_chunk = chunkify(q);
612 ggml_tensor * vb_chunk = chunkify(vb);
613
614// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
615 ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
616 ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
617 ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
618 ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
619
620 ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
621
622 // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
623 // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
624 ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
625
626 // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
627 ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
628 ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
629
630 // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
631 // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
632 // or Gamma_[t]*Q_]t] @ S
633 ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk);
634 ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
635 attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
636
637 // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
638 // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
639 ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);
640
641 // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
642 ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
643
644 core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
645
646 ggml_tensor * gk_cum_last =
647 ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
648 gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
649 gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
650
651 ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
652
653 ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
654
655 ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
656
657 ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
658
659 // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
660 ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
661
662 new_state = ggml_add(ctx0,
663 ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
664 ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
665 }
666
667 core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
668
669 // truncate padded tokens
670 ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
671 S_v, n_tokens, H_v, n_seqs,
672 ggml_row_size(core_attn_out->type, S_v),
673 ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
674 ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
675 output_tokens = ggml_cont(ctx0, output_tokens);
676 // permute back to (S_v, H_v, n_tokens, n_seqs)
677 output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
678 output_tokens = ggml_cont(ctx0, output_tokens);
679
680 cb(new_state, "output_state", il);
681
682 return {output_tokens, new_state};
683}
684
685std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoregressive(
686 ggml_tensor * q,
687 ggml_tensor * k,
688 ggml_tensor * v,
689 ggml_tensor * gk,
690 ggml_tensor * beta,
691 ggml_tensor * state,
692 int il) {
693 GGML_ASSERT(ggml_is_contiguous(v));
694 GGML_ASSERT(ggml_is_contiguous(gk));
695
696 const int64_t S_k = q->ne[0];
697 const int64_t H_k = q->ne[1];
698 const int64_t n_tokens = q->ne[2];
699 const int64_t n_seqs = q->ne[3];
700
701 const int64_t S_v = v->ne[0];
702 const int64_t H_v = v->ne[1];
703
704 GGML_ASSERT(n_tokens == 1);
705 GGML_ASSERT(v->ne[2] == n_tokens);
706 GGML_ASSERT(k->ne[2] == n_tokens);
707 GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
708 GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
709 GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
710
711 GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
712 GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
713
714 GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
715
716 const float eps_norm = hparams.f_norm_rms_eps;
717
718 q = ggml_l2_norm(ctx0, q, eps_norm);
719 k = ggml_l2_norm(ctx0, k, eps_norm);
720
721 const float scale = 1.0f / sqrtf(S_v);
722
723 q = ggml_scale(ctx0, q, scale);
724 beta = ggml_sigmoid(ctx0, beta);
725
726 cb(q, "q_in", il);
727 cb(k, "k_in", il);
728 cb(v, "v_in", il);
729 cb(beta, "beta_in", il);
730 cb(gk, "gk_in", il);
731
732// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
733// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
734// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
735 gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
736 ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
737 ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
738
739 // Apply exponential to gk_t
740 gk_t = ggml_exp(ctx0, gk_t);
741 // Apply the gated delta rule for the single timestep
742 // last_recurrent_state = last_recurrent_state * gk_t
743 // S = S * g_i[..., None].exp()
744 state = ggml_mul(ctx0, state, gk_t);
745
746 ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
747
748// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
749 k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
750 ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
751
752 // v_i - (k_i[..., None] * S).sum(-2)
753 v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
754 ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
755
756 // b_i[..., None] * k_i
757 ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
758
759 // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
760 // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
761 state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
762
763 q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
764 state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
765 ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
766 // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
767 cb(core_attn_out, "output_tokens", il);
768 cb(state, "new_state", il);
769
770 return {core_attn_out, state};
771}
772