1#include "ggml.h"
  2#include "models.h"
  3
  4#define CHUNK_SIZE 64
  5
  6llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
  7    llm_graph_context_mamba(params), model(model) {
  8    const int64_t n_embd_head = hparams.n_embd_head_v;
  9
 10    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 11
 12    int sections[4];
 13    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 14
 15    ggml_tensor * cur;
 16    ggml_tensor * inpL;
 17
 18    inpL = build_inp_embd(model.tok_embd);
 19
 20    cb(inpL, "model.input_embed", -1);
 21
 22    auto * inp = build_inp_mem_hybrid();
 23
 24    ggml_tensor * inp_pos     = build_inp_pos();
 25    ggml_tensor * inp_out_ids = build_inp_out_ids();
 26
 27    ggml_tensor * causal_mask =
 28        ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
 29                    GGML_TRI_TYPE_LOWER);
 30
 31    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
 32    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
 33
 34    ggml_build_forward_expand(gf, causal_mask);
 35    ggml_build_forward_expand(gf, identity);
 36    ggml_build_forward_expand(gf, diag_mask);
 37
 38    for (int il = 0; il < n_layer; ++il) {
 39        ggml_tensor * inpSA = inpL;
 40
 41        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
 42        cb(cur, "attn_norm", il);
 43
 44        // Determine layer type and build appropriate attention mechanism
 45        if (hparams.is_recurrent(il)) {
 46            // Linear attention layer (gated delta net)
 47            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
 48        } else {
 49            // Full attention layer
 50            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
 51        }
 52
 53        if (il == n_layer - 1 && inp_out_ids) {
 54            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
 55            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
 56        }
 57
 58        // Residual connection
 59        cur = ggml_add(ctx0, cur, inpSA);
 60        cb(cur, "attn_residual", il);
 61
 62        // Save the tensor before post-attention norm for residual connection
 63        ggml_tensor * ffn_residual = cur;
 64
 65        // Post-attention norm
 66        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
 67        cb(attn_post_norm, "attn_post_norm", il);
 68
 69        // Dense FFN layer - without residual connection
 70        cur = build_layer_ffn(attn_post_norm, il);
 71        cb(cur, "ffn_out", il);
 72
 73        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
 74        cur = ggml_add(ctx0, cur, ffn_residual);
 75        cb(cur, "post_ffn", il);
 76
 77        // Input for next layer
 78        inpL = cur;
 79    }
 80    cur = inpL;
 81
 82    // Final norm
 83    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
 84
 85    cb(cur, "result_norm", -1);
 86    res->t_embd = cur;
 87
 88    // LM head
 89    cur = build_lora_mm(model.output, cur);
 90
 91    cb(cur, "result_output", -1);
 92    res->t_logits = cur;
 93
 94    ggml_build_forward_expand(gf, cur);
 95}
 96
 97// utility to get one slice from the third dimension
 98// input dim:  [x, y, c, b]
 99// output dim: [x, y, 1, b]
100static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
101    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
102        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
103}
104
105std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_chunking(
106        ggml_tensor * q,
107        ggml_tensor * k,
108        ggml_tensor * v,
109        ggml_tensor * g,
110        ggml_tensor * beta,
111        ggml_tensor * state,
112        ggml_tensor * causal_mask,
113        ggml_tensor * identity,
114        ggml_tensor * diag_mask,
115        int           il) {
116    const int64_t S_k      = q->ne[0];
117    const int64_t H_k      = q->ne[1];
118    const int64_t n_tokens = q->ne[2];
119    const int64_t n_seqs   = q->ne[3];
120
121    const int64_t S_v = v->ne[0];
122    const int64_t H_v = v->ne[1];
123
124    GGML_ASSERT(v->ne[2] == n_tokens);
125    GGML_ASSERT(k->ne[2] == n_tokens);
126    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
127    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
128    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
129
130    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
131    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
132
133    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
134
135    const float eps_norm = hparams.f_norm_rms_eps;
136
137    q = ggml_l2_norm(ctx0, q, eps_norm);
138    k = ggml_l2_norm(ctx0, k, eps_norm);
139
140    const float scale = 1.0f / sqrtf(S_v);
141
142    q = ggml_scale(ctx0, q, scale);
143
144    beta = ggml_sigmoid(ctx0, beta);
145
146    cb(q, "q_in", il);
147    cb(k, "k_in", il);
148    cb(v, "v_in", il);
149    cb(beta, "beta_in", il);
150    cb(g, "g_in", il);
151
152    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
153    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
154    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
155    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
156
157    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
158    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
159
160    cb(q, "q_perm", il);
161    cb(k, "k_perm", il);
162    cb(v, "v_perm", il);
163    cb(beta, "beta_perm", il);
164    cb(g, "g_perm", il);
165    cb(state, "state_in", il);
166
167    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
168    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
169    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
170    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
171
172    // Do padding
173    const int64_t chunk_size = CHUNK_SIZE;
174
175    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
176    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
177
178    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
179    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
180    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
181    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
182    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
183
184    cb(q, "q_pad", il);
185    cb(k, "k_pad", il);
186    cb(v, "v_pad", il);
187    cb(beta, "beta_pad", il);
188    cb(g, "g_pad", il);
189
190    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
191    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
192
193    cb(v_beta, "v_beta", il);
194    cb(k_beta, "k_beta", il);
195
196    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
197    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
198    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
199    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
200    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
201
202    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
203    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
204
205    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
206    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
207
208    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
209    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
210
211    ggml_tensor * gcs_j_broadcast =
212        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
213
214    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
215    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
216
217    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
218    decay_mask = ggml_exp(ctx0, decay_mask);
219    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
220
221    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
222
223    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
224    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
225    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
226
227    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
228    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
229
230    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
231    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
232    attn                     = ggml_add(ctx0, attn, identity);
233    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
234
235    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
236
237    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
238    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
239
240    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
241    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
242
243    ggml_tensor * k_cumdecay =
244        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
245    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
246
247    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
248    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
249    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
250    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
251
252
253    // vectorized calculation of key_gdiff
254    // improved from the chunked version:
255    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
256    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
257    //   key_gdiff = key * g_diff.unsqueeze(-1)
258    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
259    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
260
261    // get last element in g_cumsum along chunk_size dimension (ne0)
262    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
263    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
264                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
265                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
266    g_last = ggml_cont(ctx0, g_last);
267    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
268
269    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
270    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
271
272    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
273    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
274
275    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
276    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
277                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
278
279    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
280    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
281
282    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
283    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
284
285    // state to be updated per chunk
286    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
287    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
288
289    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
290    ggml_tensor * core_attn_out = nullptr;
291
292    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
293        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
294        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
295
296        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
297        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
298
299        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
300        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
301
302        // shape: (chunk_size, 1, H_v * n_seqs)
303        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
304
305        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
306        // replaced by precomputed attn_kq
307        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
308        cb(attn_chunk, "attn_chunk", il);
309
310        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
311
312        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
313        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
314        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
315
316        // v_new = v_i - v_prime
317        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
318        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
319        cb(v_new, "v_new_chunk", il);
320
321        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
322        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
323        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
324        cb(attn_inter, "attn_inter_chunk", il);
325
326        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
327        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
328        cb(v_attn, "v_attn_chunk", il);
329
330        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
331        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
332
333        core_attn_out = core_attn_out == nullptr
334            ? core_attn_out_chunk
335            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
336
337        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
338        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
339        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
340        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
341
342        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
343        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
344        new_state = ggml_add(ctx0,
345            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
346            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
347    }
348
349    // truncate padded tokens
350    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
351            S_v, n_tokens, H_v, n_seqs,
352            ggml_row_size(core_attn_out->type, S_v),
353            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
354            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
355    output_tokens = ggml_cont(ctx0, output_tokens);
356    cb(output_tokens, "output_tokens", il);
357
358    // permute back to (S_v, H_v, n_tokens, n_seqs)
359    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
360    output_tokens = ggml_cont(ctx0, output_tokens);
361
362    return {output_tokens, new_state};
363}
364
365std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_autoregressive(
366        ggml_tensor * q,
367        ggml_tensor * k,
368        ggml_tensor * v,
369        ggml_tensor * g,
370        ggml_tensor * beta,
371        ggml_tensor * state,
372        int           il) {
373    const int64_t S_k      = q->ne[0];
374    const int64_t H_k      = q->ne[1];
375    const int64_t n_tokens = q->ne[2];
376    const int64_t n_seqs   = q->ne[3];
377
378    const int64_t S_v = v->ne[0];
379    const int64_t H_v = v->ne[1];
380
381    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
382    GGML_ASSERT(v->ne[2] == n_tokens);
383    GGML_ASSERT(k->ne[2] == n_tokens);
384    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
385    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
386    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
387
388    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
389    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
390
391    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
392
393    const float eps_norm = hparams.f_norm_rms_eps;
394
395    q = ggml_l2_norm(ctx0, q, eps_norm);
396    k = ggml_l2_norm(ctx0, k, eps_norm);
397
398    const float scale = 1.0f / sqrtf(S_v);
399
400    q    = ggml_scale(ctx0, q, scale);
401    beta = ggml_sigmoid(ctx0, beta);
402
403    cb(q, "q_in", il);
404    cb(k, "k_in", il);
405    cb(v, "v_in", il);
406    cb(beta, "beta_in", il);
407    cb(g, "g_in", il);
408
409    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
410
411    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
412    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
413
414    // Apply exponential to g_t
415    g_t = ggml_exp(ctx0, g_t);
416
417    // Apply the gated delta rule for the single timestep
418    // last_recurrent_state = last_recurrent_state * g_t
419    state = ggml_mul(ctx0, state, g_t);
420
421    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
422    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
423    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
424    // we need to sum over dim=-2, so we transpose, sum, then transpose again
425    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
426
427    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
428    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
429    // delta = (v_t - kv_mem) * beta_t
430    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
431    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
432
433    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
434    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
435    state                   = ggml_add(ctx0, state, k_t_delta);
436
437    // Compute the attention output
438    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
439    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
440    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
441    // again, since it's over dim = -2, transpose, sum, transpose back
442    ggml_tensor * core_attn_out =
443        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
444
445    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
446    cb(core_attn_out, "output_tokens", il);
447    cb(state, "new_state", il);
448
449    return {core_attn_out, state};
450}
451
452std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
453                ggml_tensor * input,
454                        int   il) {
455    const int64_t n_seqs       = ubatch.n_seqs;
456    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
457
458    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
459    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
460    cb(qkv_mixed, "linear_attn_qkv_mixed", il);
461
462    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
463    cb(z, "z", il);
464
465    return { qkv_mixed, z };
466}
467
468ggml_tensor * llm_build_qwen35::build_norm_gated(
469        ggml_tensor * input,
470        ggml_tensor * weights,
471        ggml_tensor * gate,
472        int           layer) {
473    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
474    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
475
476    return ggml_mul(ctx0, normalized, gated_silu);
477}
478
479ggml_tensor * llm_build_qwen35::build_layer_attn(
480        llm_graph_input_attn_kv * inp,
481        ggml_tensor *             cur,
482        ggml_tensor *             inp_pos,
483        int *                     sections,
484        int                       il) {
485    const int64_t n_embd_head = hparams.n_embd_head_v;
486    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
487
488    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
489
490    // Qwen3Next uses a single Q projection that outputs query + gate
491    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
492    cb(Qcur_full, "Qcur_full", il);
493
494    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
495        ggml_element_size(Qcur_full) * n_embd_head * 2,
496        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
497    cb(Qcur, "Qcur_reshaped", il);
498
499    // Apply Q normalization
500    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
501    cb(Qcur, "Qcur_normed", il);
502
503    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
504    cb(Kcur, "Kcur", il);
505
506    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
507    cb(Vcur, "Vcur", il);
508
509    // Apply K normalization
510    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
511    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
512    cb(Kcur, "Kcur_normed", il);
513
514    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
515        ggml_element_size(Qcur_full) * n_embd_head * 2,
516        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
517        ggml_element_size(Qcur_full) * n_embd_head);
518    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
519    cb(gate, "gate_reshaped", il);
520
521    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
522
523    // Apply MRoPE
524    Qcur = ggml_rope_multi(
525            ctx0, Qcur, inp_pos, nullptr,
526            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
527            ext_factor, attn_factor, beta_fast, beta_slow
528            );
529
530    Kcur = ggml_rope_multi(
531            ctx0, Kcur, inp_pos, nullptr,
532            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
533            ext_factor, attn_factor, beta_fast, beta_slow
534            );
535
536    cb(Qcur, "Qcur", il);
537    cb(Kcur, "Kcur", il);
538    cb(Vcur, "Vcur", il);
539
540    // Attention computation
541    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
542
543    cur = build_attn(inp,
544                nullptr, nullptr,
545                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
546    cb(cur, "attn_pregate", il);
547
548    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
549    cb(gate_sigmoid, "gate_sigmoid", il);
550
551    cur = ggml_mul(ctx0, cur, gate_sigmoid);
552    cb(cur, "attn_gated", il);
553
554    cur = build_lora_mm(model.layers[il].wo, cur);
555    cb(cur, "attn_output", il);
556
557    return cur;
558}
559
560ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
561        llm_graph_input_rs * inp,
562        ggml_tensor *        cur,
563        ggml_tensor *        causal_mask,
564        ggml_tensor *        identity,
565        ggml_tensor *        diag_mask,
566        int                  il) {
567    const auto * mctx_cur = inp->mctx;
568
569    const int64_t d_inner      = hparams.ssm_d_inner;
570    const int64_t n_seqs       = ubatch.n_seqs;
571    const int64_t head_k_dim   = hparams.ssm_d_state;
572    const int64_t num_k_heads  = hparams.ssm_n_group;
573    const int64_t num_v_heads  = hparams.ssm_dt_rank;
574    const int64_t head_v_dim   = d_inner / num_v_heads;
575    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
576
577    const auto kv_head = mctx_cur->get_head();
578
579    GGML_ASSERT(n_seqs != 0);
580    GGML_ASSERT(ubatch.equal_seqs());
581    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
582
583    // Input projections
584    auto qkvz = build_qkvz(cur, il);
585    ggml_tensor * qkv_mixed = qkvz.first;
586    ggml_tensor * z         = qkvz.second;
587
588    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
589    beta  = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
590    cb(beta, "beta", il);
591    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
592    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
593    cb(alpha, "alpha", il);
594
595    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
596    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
597    cb(alpha_softplus, "a_softplus", il);
598    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
599    cb(gate, "gate", il);
600
601    // Get convolution states from cache
602    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
603    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
604
605    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
606
607    // Build the convolution states tensor
608    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
609    cb(conv_states, "conv_states", il);
610
611    // Calculate convolution kernel size
612    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
613    const int64_t conv_kernel_size = conv_kernel->ne[0];
614    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
615    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
616    cb(conv_states, "conv_states_reshaped", il);
617
618    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
619    cb(qkv_mixed, "qkv_mixed_permuted", il);
620
621    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
622    cb(conv_input, "conv_input", il);
623
624    // Update convolution state cache
625    // Extract the last (conv_kernel_size - 1) states from conv_input
626    ggml_tensor * last_conv_states =
627        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
628                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
629    cb(last_conv_states, "last_conv_states", il);
630
631    ggml_tensor * state_update_target =
632        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
633                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
634    cb(state_update_target, "state_update_target", il);
635
636    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
637    cb(conv_states_all, "conv_states_updated", il);
638
639    // Apply SSM convolution
640    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
641    cb(conv_output_proper, "conv_output_raw", il);
642
643    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
644    cb(conv_output_silu, "conv_output_silu", il);
645
646    ggml_tensor * conv_qkv_mix = conv_output_silu;
647
648    // Calculate the total conv dimension
649    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
650    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
651
652    // Extract the convolved Q, K, V from conv_output
653    ggml_tensor * q_conv =
654        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
655    cb(q_conv, "q_conv", il);
656    ggml_tensor * k_conv =
657        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
658                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
659    cb(k_conv, "k_conv", il);
660    ggml_tensor * v_conv =
661        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
662                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
663    cb(v_conv, "v_conv", il);
664
665    // Unsqueeze them
666    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
667    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
668    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
669
670    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
671    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
672    cb(state, "state_predelta", il);
673
674    // if head keys and value keys are different, repeat Q/K to match V's head count
675    // V heads are in tiled order (from conversion), so simple tiled repeat works
676    if (num_k_heads != num_v_heads) {
677        GGML_ASSERT(num_v_heads % num_k_heads == 0);
678        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
679        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
680    }
681
682    cb(q_conv, "q_conv_predelta", il);
683    cb(k_conv, "k_conv_predelta", il);
684    cb(v_conv, "v_conv_predelta", il);
685
686    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
687    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
688    if (n_seq_tokens == 1) {
689        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
690    } else {
691        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
692    }
693    ggml_tensor * output    = attn_out.first;
694    ggml_tensor * new_state = attn_out.second;
695    cb(output, "attn_output", il);
696    cb(new_state, "new_state", il);
697
698    // Update the recurrent states
699    ggml_build_forward_expand(gf,
700                              ggml_cpy(ctx0, new_state,
701                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
702                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
703
704    // Reshape both attn_out_final and z to 2D tensors for normalization
705    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
706    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
707
708    // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
709    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
710
711    // Apply gated normalization: self.norm(core_attn_out, z)
712    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
713
714    // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
715    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
716    cb(final_output, "final_output", il);
717
718    // Output projection
719    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
720    cb(cur, "linear_attn_out", il);
721
722    // Reshape back to original dimensions
723    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
724    return cur;
725}
726
727ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) {
728    // Qwen3.5 does not use MoE FFN
729    GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr);
730
731    cur = build_ffn(cur,
732        model.layers[il].ffn_up, NULL, NULL,
733        model.layers[il].ffn_gate, NULL, NULL,
734        model.layers[il].ffn_down, NULL, NULL,
735        NULL,
736        LLM_FFN_SILU, LLM_FFN_PAR, il);
737    cb(cur, "ffn_out", il);
738
739    return cur;
740}