#include "ggml.h"
#include "models.h"

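// number of tokens processed per chunk in the chunked delta-net path
// (build_delta_net_chunking); longer inputs are padded to a multiple of this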
#define CHUNK_SIZE 64

llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context_mamba(params), model(model) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    int sections[4];
    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    cb(inpL, "model.input_embed", -1);

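    // hybrid memory input: recurrent-state slots for the linear-attention
    // (delta net) layers and a regular KV cache for the full-attention layers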
    auto * inp = build_inp_mem_hybrid();

    ggml_tensor * inp_pos     = build_inp_pos();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

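    // constant masks shared by all linear-attention layers: causal_mask has
    // ones strictly below the diagonal, identity is the CHUNK_SIZE x CHUNK_SIZE
    // identity, and diag_mask = causal_mask + identity also covers the diagonal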
    ggml_tensor * causal_mask =
        ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                    GGML_TRI_TYPE_LOWER);

    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

    ggml_build_forward_expand(gf, causal_mask);
    ggml_build_forward_expand(gf, identity);
    ggml_build_forward_expand(gf, diag_mask);

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // Determine layer type and build appropriate attention mechanism
        if (hparams.is_recurrent(il)) {
            // Linear attention layer (gated delta net)
            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
        } else {
            // Full attention layer
            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
        }

        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        // Residual connection
        cur = ggml_add(ctx0, cur, inpSA);
        cb(cur, "attn_residual", il);

        // Save the tensor before post-attention norm for residual connection
        ggml_tensor * ffn_residual = cur;

        // Post-attention norm
        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
        cb(attn_post_norm, "attn_post_norm", il);

        // MoE FFN layer
        cur = build_layer_ffn(attn_post_norm, il);
        cb(cur, "ffn_out", il);

        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
        cur = ggml_add(ctx0, cur, ffn_residual);
        cb(cur, "post_moe", il);

        // Input for next layer
        inpL = cur;
    }
    cur = inpL;

    // Final norm
    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // LM head
    cur = build_lora_mm(model.output, cur);

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

// utility to get one slice from the third dimension
// input dim:  [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q = ggml_scale(ctx0, q, scale);

    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);

    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    cb(q, "q_perm", il);
    cb(k, "k_perm", il);
    cb(v, "v_perm", il);
    cb(beta, "beta_perm", il);
    cb(g, "g_perm", il);
    cb(state, "state_in", il);

    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

    // Do padding
    const int64_t chunk_size = CHUNK_SIZE;

    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
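    // e.g. n_tokens = 100 with chunk_size = 64 gives pad = 28 and n_chunks = 2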

    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);

    cb(q, "q_pad", il);
    cb(k, "k_pad", il);
    cb(v, "v_pad", il);
    cb(beta, "beta_pad", il);
    cb(g, "g_pad", il);

    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);

    cb(v_beta, "v_beta", il);
    cb(k_beta, "k_beta", il);

    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);

    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * gcs_j_broadcast =
        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

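    // mask the gap matrix before the exp(): masked entries become exp(0) == 1,
    // so the second multiply with diag_mask is needed to zero them out again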
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);

    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

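    // attn_lower holds the strictly lower-triangular system A; solve
    // (I - A) X = A by triangular substitution and add the identity, which
    // yields (I - A)^-1 directly, since (I - A)^-1 A + I == (I - A)^-1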
    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
    attn                     = ggml_add(ctx0, attn, identity);
    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);

    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);

    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * k_cumdecay =
        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
    cb(k_cumdecay, "k_cumdecay", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    // vectorized calculation of key_gdiff
    // improved from the chunked version:
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
    //   key_gdiff = key * g_diff.unsqueeze(-1)
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

    // get last element in g_cumsum along chunk_size dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
    g_last = ggml_cont(ctx0, g_last);
    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);

    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)

    // state to be updated per chunk
    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)

    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
    ggml_tensor * core_attn_out = nullptr;

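    // chunkwise recurrence: for each chunk, remove what the running state
    // already explains (v_prime), produce the chunk output as intra-chunk
    // attention plus the inter-chunk readout, then decay the state by
    // exp(g_last) and fold in the new key/value contribution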
    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul

        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat

        // shape: (1, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul

        // shape: (S_k, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat

        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        // replaced by precomputed attn_kq
        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
        cb(attn_chunk, "attn_chunk", il);

        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        // v_new = v_i - v_prime
        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
        cb(v_new, "v_new_chunk", il);

        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
        cb(attn_inter, "attn_inter_chunk", il);

        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
        cb(v_attn, "v_attn_chunk", il);

        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        core_attn_out = core_attn_out == nullptr
            ? core_attn_out_chunk
            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
        // ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
        new_state = ggml_add(ctx0,
            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
    }

    // truncate padded tokens
    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(core_attn_out->type, S_v),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
    output_tokens = ggml_cont(ctx0, output_tokens);
    cb(output_tokens, "output_tokens", il);

    // permute back to (S_v, H_v, n_tokens, n_seqs)
    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
    output_tokens = ggml_cont(ctx0, output_tokens);

    return {output_tokens, new_state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q    = ggml_scale(ctx0, q, scale);
    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

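    // single-step gated delta rule, the chunked path specialized to one token:
    // decay the state by exp(g), read out what the memory returns for the
    // current key (kv_mem), correct it towards v with learning rate beta,
    // write the correction back into the state, then read the output with q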
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);

    // Apply exponential to g_t
    g_t = ggml_exp(ctx0, g_t);

    // Apply the gated delta rule for the single timestep
    // last_recurrent_state = last_recurrent_state * g_t
    state = ggml_mul(ctx0, state, g_t);

    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
    // we need to sum over dim=-2, so we transpose, sum, then transpose again
    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));

    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
    // delta = (v_t - kv_mem) * beta_t
    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);

    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
    state                   = ggml_add(ctx0, state, k_t_delta);

    // Compute the attention output
    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
    // again, since it's over dim = -2, transpose, sum, transpose back
    ggml_tensor * core_attn_out =
        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
    cb(core_attn_out, "output_tokens", il);
    cb(state, "new_state", il);

    return {core_attn_out, state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
        ggml_tensor * input,
        int           il) {
    const int64_t n_seqs       = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
    cb(qkv_mixed, "linear_attn_qkv_mixed", il);

    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
    cb(z, "z", il);

    return { qkv_mixed, z };
}

ggml_tensor * llm_build_qwen35moe::build_norm_gated(
        ggml_tensor * input,
        ggml_tensor * weights,
        ggml_tensor * gate,
        int           layer) {
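    // gated RMS norm used on the linear-attention output: RMSNorm(input) * SiLU(gate)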
    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);

    return ggml_mul(ctx0, normalized, gated_silu);
}

ggml_tensor * llm_build_qwen35moe::build_layer_attn(
        llm_graph_input_attn_kv * inp,
        ggml_tensor *             cur,
        ggml_tensor *             inp_pos,
        int *                     sections,
        int                       il) {
    const int64_t n_embd_head = hparams.n_embd_head_v;
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention

    // Qwen3Next uses a single Q projection that outputs query + gate
    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
    cb(Qcur_full, "Qcur_full", il);

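    // each head occupies a 2 * n_embd_head block of Qcur_full laid out as
    // [query | gate]; the strided view here takes the query half and the view
    // further down takes the gate half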
    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
        ggml_element_size(Qcur_full) * n_embd_head * 2,
        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
    cb(Qcur, "Qcur_reshaped", il);

    // Apply Q normalization
    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
    cb(Qcur, "Qcur_normed", il);

    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
    cb(Kcur, "Kcur", il);

    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
    cb(Vcur, "Vcur", il);

    // Apply K normalization
    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
    cb(Kcur, "Kcur_normed", il);

    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
        ggml_element_size(Qcur_full) * n_embd_head * 2,
        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
        ggml_element_size(Qcur_full) * n_embd_head);
    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
    cb(gate, "gate_reshaped", il);

    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

    // Apply IMRoPE
    Qcur = ggml_rope_multi(
            ctx0, Qcur, inp_pos, nullptr,
            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow
            );

    Kcur = ggml_rope_multi(
            ctx0, Kcur, inp_pos, nullptr,
            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow
            );

    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
    cb(Vcur, "Vcur", il);

    // Attention computation
    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

    cur = build_attn(inp,
                nullptr, nullptr,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
    cb(cur, "attn_pregate", il);

    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
    cb(gate_sigmoid, "gate_sigmoid", il);

    cur = ggml_mul(ctx0, cur, gate_sigmoid);
    cb(cur, "attn_gated", il);

    cur = build_lora_mm(model.layers[il].wo, cur);
    cb(cur, "attn_output", il);

    return cur;
}

ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor *        cur,
        ggml_tensor *        causal_mask,
        ggml_tensor *        identity,
        ggml_tensor *        diag_mask,
        int                  il) {
    const auto * mctx_cur = inp->mctx;

    const int64_t d_inner      = hparams.ssm_d_inner;
    const int64_t n_seqs       = ubatch.n_seqs;
    const int64_t head_k_dim   = hparams.ssm_d_state;
    const int64_t num_k_heads  = hparams.ssm_n_group;
    const int64_t num_v_heads  = hparams.ssm_dt_rank;
    const int64_t head_v_dim   = d_inner / num_v_heads;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    const auto kv_head = mctx_cur->get_head();

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    // Input projections
    auto qkvz = build_qkvz(cur, il);
    ggml_tensor * qkv_mixed = qkvz.first;
    ggml_tensor * z         = qkvz.second;

    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
    beta  = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
    cb(beta, "beta", il);
    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
    cb(alpha, "alpha", il);

    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
    cb(alpha_softplus, "a_softplus", il);
    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
    cb(gate, "gate", il);

    // Get convolution states from cache
    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();

    // Build the convolution states tensor
    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    cb(conv_states, "conv_states", il);

    // Calculate convolution kernel size
    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
    const int64_t conv_kernel_size = conv_kernel->ne[0];
    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
    cb(conv_states, "conv_states_reshaped", il);

    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
    cb(qkv_mixed, "qkv_mixed_permuted", il);

    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
    cb(conv_input, "conv_input", il);
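    // conv_input is the cached tail of the previous ubatch followed by the new
    // tokens, so the depthwise convolution has full context at the window start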

    // Update convolution state cache
    // Extract the last (conv_kernel_size - 1) states from conv_input
    ggml_tensor * last_conv_states =
        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
    cb(last_conv_states, "last_conv_states", il);

    ggml_tensor * state_update_target =
        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
    cb(state_update_target, "state_update_target", il);

    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
    cb(conv_states_all, "conv_states_updated", il);

    // Apply SSM convolution
    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
    cb(conv_output_proper, "conv_output_raw", il);

    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
    cb(conv_output_silu, "conv_output_silu", il);

    ggml_tensor * conv_qkv_mix = conv_output_silu;

    // Calculate the total conv dimension
    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
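    // each row of conv_qkv_mix is laid out as [q | k | v]: q and k span
    // head_k_dim * num_k_heads each, and v spans head_v_dim * num_v_heads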

    // Extract the convolved Q, K, V from conv_output
    ggml_tensor * q_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
    cb(q_conv, "q_conv", il);
    ggml_tensor * k_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
    cb(k_conv, "k_conv", il);
    ggml_tensor * v_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
    cb(v_conv, "v_conv", il);

    // Unsqueeze them
    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
    cb(state, "state_predelta", il);

    // if the number of K heads differs from the number of V heads, repeat Q/K
    // to match V's head count; V heads are in tiled order (from conversion),
    // so a simple tiled repeat works
    if (num_k_heads != num_v_heads) {
        GGML_ASSERT(num_v_heads % num_k_heads == 0);
        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
    }

    cb(q_conv, "q_conv_predelta", il);
    cb(k_conv, "k_conv_predelta", il);
    cb(v_conv, "v_conv_predelta", il);

    // choose between build_delta_net_chunking and build_delta_net_autoregressive
    // based on the number of tokens per sequence
    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
    if (n_seq_tokens == 1) {
        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
    } else {
        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
    }
    ggml_tensor * output    = attn_out.first;
    ggml_tensor * new_state = attn_out.second;
    cb(output, "attn_output", il);
    cb(new_state, "new_state", il);

    // Update the recurrent states
    ggml_build_forward_expand(gf,
                              ggml_cpy(ctx0, new_state,
                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

    // Reshape both the attention output and z to 2D tensors for normalization
    // output: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim, num_v_heads * n_tokens * n_seqs]
    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

    // z: [head_v_dim * num_v_heads, n_tokens, n_seqs] -> [head_v_dim, num_v_heads * n_tokens * n_seqs]
    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

    // Apply gated normalization: self.norm(core_attn_out, z)
    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);

    // Final reshape: [head_v_dim, num_v_heads * n_tokens * n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
    cb(final_output, "final_output", il);

    // Output projection
    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
    cb(cur, "linear_attn_out", il);

    // Reshape back to original dimensions
    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
    return cur;
}

ggml_tensor * llm_build_qwen35moe::build_layer_ffn(ggml_tensor * cur, const int il) {
    // Check if this is an MoE layer
    GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr);

    ggml_tensor * moe_out =
        build_moe_ffn(cur,
            model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
            model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
            nullptr,
            n_expert, n_expert_used, LLM_FFN_SILU,
            true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
    cb(moe_out, "ffn_moe_out", il);

    // Add shared experts if present - following the Qwen3Next reference implementation
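    // final output: moe_out + sigmoid(shared_gate(x)) * shared_expert(x)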
    if (model.layers[il].ffn_up_shexp != nullptr) {
        ggml_tensor * ffn_shexp =
            build_ffn(cur,
                model.layers[il].ffn_up_shexp, NULL, NULL,
                model.layers[il].ffn_gate_shexp, NULL, NULL,
                model.layers[il].ffn_down_shexp, NULL, NULL,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, il);
        cb(ffn_shexp, "ffn_shexp", il);

        // Apply shared expert gating as in the reference implementation
        // The shared expert has its own gate that is sigmoided
        // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
        ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
        cb(shared_gate, "shared_expert_gate", il);

        // Apply sigmoid to the gate
        shared_gate = ggml_sigmoid(ctx0, shared_gate);
        cb(shared_gate, "shared_expert_gate_sigmoid", il);

        // Apply the gate to the shared expert output
        ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
        cb(ffn_shexp, "ffn_shexp_gated", il);

        cur = ggml_add(ctx0, moe_out, ffn_shexp);
        cb(cur, "ffn_out", il);
    } else {
        cur = moe_out;
    }

    return cur;
}