#include "ggml.h"
#include "models.h"

#define CHUNK_SIZE 64

llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context_mamba(params), model(model) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "model.embed_tokens", -1);

    auto * inp = build_inp_mem_hybrid();

    ggml_tensor * inp_pos     = build_inp_pos();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

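    // Shared constant masks for the chunked delta-net path (build_delta_net_chunking):
    // a lower-triangular causal mask, an identity matrix, and their sum (diag_mask),
    // each CHUNK_SIZE x CHUNK_SIZE; they are built once here and reused by every
    // linear-attention layer.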
    ggml_tensor * causal_mask =
        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                    GGML_TRI_TYPE_LOWER);

    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

    ggml_build_forward_expand(gf, causal_mask);
    ggml_build_forward_expand(gf, identity);
    ggml_build_forward_expand(gf, diag_mask);

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // Determine layer type and build appropriate attention mechanism
        if (hparams.is_recurrent(il)) {
            // Linear attention layer (gated delta net)
            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
        } else {
            // Full attention layer
            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
        }

        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        // Residual connection
        cur = ggml_add(ctx0, cur, inpSA);
        cb(cur, "attn_residual", il);

        // Save the tensor before post-attention norm for residual connection
        ggml_tensor * ffn_residual = cur;

        // Post-attention norm
        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
        cb(attn_post_norm, "attn_post_norm", il);

        // FFN layer (MoE or dense) - without residual connection
        cur = build_layer_ffn(attn_post_norm, il);
        cb(cur, "ffn_out", il);

        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
        cur = ggml_add(ctx0, cur, ffn_residual);
        cb(cur, "post_moe", il);

        // Input for next layer
        inpL = cur;
    }
    cur = inpL;

    // Final norm
    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // LM head
    cur = build_lora_mm(model.output, cur);

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

// utility to get one slice from the third dimension
// input dim:  [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}

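// Chunked evaluation of the gated delta rule (multi-token / prompt-processing path).
// Tokens are padded to a multiple of CHUNK_SIZE; within a chunk the interactions are
// expressed as a lower-triangular system solved with ggml_solve_tri, while the
// recurrent state is carried across chunks. Roughly, per chunk i (pseudocode sketch,
// mirroring the reference torch implementation quoted further below):
//
//   v_prime = k_cumdecay_i @ S                      // contribution of the carried state
//   v_new   = (attn_i @ v_beta_i) - v_prime         // delta-corrected values
//   out_i   = (q_i * exp(g_i)) @ S + attn_kq_i @ v_new
//   S       = S * exp(g_last_i) + key_gdiff_i^T @ v_new
//
// Returns {output_tokens, new_state}.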
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q = ggml_scale(ctx0, q, scale);

    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);

    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    cb(q, "q_perm", il);
    cb(k, "k_perm", il);
    cb(v, "v_perm", il);
    cb(beta, "beta_perm", il);
    cb(g, "g_perm", il);
    cb(state, "state_in", il);

    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

    // Do padding
    const int64_t chunk_size = CHUNK_SIZE;

    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
    const int64_t n_chunks = (n_tokens + pad) / chunk_size;

    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);

    cb(q, "q_pad", il);
    cb(k, "k_pad", il);
    cb(v, "v_pad", il);
    cb(beta, "beta_pad", il);
    cb(g, "g_pad", il);

    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);

    cb(v_beta, "v_beta", il);
    cb(k_beta, "k_beta", il);

    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);

    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * gcs_j_broadcast =
        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);

    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
    attn                     = ggml_add(ctx0, attn, identity);
    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);

    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);

    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * k_cumdecay =
        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
    cb(k_cumdecay, "k_cumdecay", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)


    // vectorized calculation of key_gdiff
    // improved from the chunked version:
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
    //   key_gdiff = key * g_diff.unsqueeze(-1)
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

    // get last element in g_cumsum along chunk_size dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
    g_last = ggml_cont(ctx0, g_last);
    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);

    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)


    // state to be updated per chunk
    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)

    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
    ggml_tensor * core_attn_out = nullptr;

    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul

        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat

        // shape: (1, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul

        // shape: (S_k, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat

        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        // replaced by precomputed attn_kq
        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
        cb(attn_chunk, "attn_chunk", il);

        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        // v_new = v_i - v_prime
        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
        cb(v_new, "v_new_chunk", il);

        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
        cb(attn_inter, "attn_inter_chunk", il);

        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
        cb(v_attn, "v_attn_chunk", il);

        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        core_attn_out = core_attn_out == nullptr
            ? core_attn_out_chunk
            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
        new_state = ggml_add(ctx0,
            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
    }

    // truncate padded tokens
    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(core_attn_out->type, S_v),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
    output_tokens = ggml_cont(ctx0, output_tokens);
    cb(output_tokens, "output_tokens", il);

    // permute back to (S_v, H_v, n_tokens, n_seqs)
    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
    output_tokens = ggml_cont(ctx0, output_tokens);

    return {output_tokens, new_state};
}

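// Single-token (autoregressive decode) evaluation of the gated delta rule. With one
// token there is no intra-chunk mixing, so the update reduces to the plain recurrence
// (roughly, per head; transposes omitted):
//
//   S     = exp(g_t) * S
//   delta = beta_t * (v_t - S k_t)
//   S     = S + outer(k_t, delta)
//   o_t   = S q_t
//
// Returns {core_attn_out, new_state}.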
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q    = ggml_scale(ctx0, q, scale);
    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);

    // Apply exponential to g_t
    g_t = ggml_exp(ctx0, g_t);

    // Apply the gated delta rule for the single timestep
    // last_recurrent_state = last_recurrent_state * g_t
    state = ggml_mul(ctx0, state, g_t);

    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
    // we need to sum over dim=-2, so we transpose, sum, then transpose again
    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));

    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
    // delta = (v_t - kv_mem) * beta_t
    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);

    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
    state                   = ggml_add(ctx0, state, k_t_delta);

    // Compute the attention output
    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
    // again, since it's over dim = -2, transpose, sum, transpose back
    ggml_tensor * core_attn_out =
        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
    cb(core_attn_out, "output_tokens", il);
    cb(state, "new_state", il);

    return {core_attn_out, state};
}

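// Gated RMS norm used by the linear-attention output: RMSNorm(input) * SiLU(gate).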
ggml_tensor * llm_build_qwen3next::build_norm_gated(
        ggml_tensor * input,
        ggml_tensor * weights,
        ggml_tensor * gate,
        int           layer) {
    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);

    return ggml_mul(ctx0, normalized, gated_silu);
}

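// Full (softmax) attention layer. The Q projection jointly produces the query and an
// output gate; after attention the result is multiplied by sigmoid(gate) before the
// output projection.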
ggml_tensor * llm_build_qwen3next::build_layer_attn(
        llm_graph_input_attn_kv * inp,
        ggml_tensor *             cur,
        ggml_tensor *             inp_pos,
        int                       il) {
    const int64_t n_embd_head = hparams.n_embd_head_v;
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention

    // Qwen3Next uses a single Q projection that outputs query + gate
    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
    cb(Qcur_full, "Qcur_full", il);

    Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);

    // Split Q projection into query and gate
    // The split should be along dimension 0 (the feature dimension)
    ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
    ggml_tensor * gate =
        ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                     Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
    cb(Qcur, "Qcur", il);
    cb(gate, "gate", il);

    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
    cb(Qcur, "Qcur_reshaped", il);

    // Apply Q normalization
    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
    cb(Qcur, "Qcur_normed", il);

    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
    cb(Kcur, "Kcur", il);

    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
    cb(Vcur, "Vcur", il);

    // Apply K normalization
    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
    cb(Kcur, "Kcur_normed", il);

    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
    cb(gate, "gate_reshaped", il);

    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

    // Apply RoPE
    Qcur = ggml_rope_ext(
            ctx0, Qcur, inp_pos, nullptr,
            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow);

    Kcur = ggml_rope_ext(
            ctx0, Kcur, inp_pos, nullptr,
            n_rot, rope_type, n_ctx_orig, freq_base,
            freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

    cb(Qcur, "Qcur", il);
    cb(Kcur, "Kcur", il);
    cb(Vcur, "Vcur", il);

    // Attention computation
    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

    cur = build_attn(inp,
                nullptr, nullptr,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
    cb(cur, "attn_pregate", il);

    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
    cb(gate_sigmoid, "gate_sigmoid", il);

    cur = ggml_mul(ctx0, cur, gate_sigmoid);
    cb(cur, "attn_gated", il);

    cur = build_lora_mm(model.layers[il].wo, cur);
    cb(cur, "attn_output", il);

    return cur;
}

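// Projects the layer input into the concatenated Q/K/V stream and the z gate used by
// the gated delta net. Uses the fused wqkv / wqkv_gate projections when available,
// otherwise splits the legacy ssm_in projection into q, k, v and z.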
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
                ggml_tensor * input,
                        int   il) {
    const int64_t d_inner      = hparams.ssm_d_inner;
    const int64_t n_seqs       = ubatch.n_seqs;
    const int64_t head_k_dim   = hparams.ssm_d_state;
    const int64_t num_k_heads  = hparams.ssm_n_group;
    const int64_t num_v_heads  = hparams.ssm_dt_rank;
    const int64_t head_v_dim   = d_inner / num_v_heads;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    if (model.layers[il].wqkv) {
        // optimized path
        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
        cb(qkv_mixed, "linear_attn_qkv_mixed", il);

        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
        cb(z, "z", il);

        return { qkv_mixed, z };

    } else {
        // legacy (slower) path
        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);

        int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
        ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);

        // Split mixed_qkvz into query, key, value, z
        int64_t split_sizes_qkvz[4] = {
            head_k_dim,                              // query size
            head_k_dim,                              // key size
            head_v_dim * num_v_heads / num_k_heads,  // value size
            head_v_dim * num_v_heads / num_k_heads   // z size
        };

        ggml_tensor * query =
            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
        cb(query, "q", il);

        ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
                                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                                        split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
        cb(key, "k", il);

        ggml_tensor * value =
            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
        cb(value, "v", il);

        ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
                                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
                                    (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
        z = ggml_cont(ctx0, z);
        cb(z, "z", il);

        // After creating query, key, and value, reshape each to flatten the head dimensions
        // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
        ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
        cb(query_flat, "query_flat", il);

        // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
        ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
        cb(key_flat, "key_flat", il);

        // value: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
        ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
        cb(value_flat, "value_flat", il);

        // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
        cb(qkv_mixed, "qkv_mixed", il);

        return { qkv_mixed, z };
    }
}

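// Linear-attention (gated delta net) layer: QKVZ projections, a short causal
// convolution over the Q/K/V stream (its state lives in the recurrent cache), the
// delta-net core (chunked for multi-token batches, single-step for decode), a gated
// RMS norm, and the output projection. Updated conv and delta-net states are written
// back to the cache.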
ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor *        cur,
        ggml_tensor *        causal_mask,
        ggml_tensor *        identity,
        ggml_tensor *        diag_mask,
        int                  il) {
    const auto * mctx_cur = inp->mctx;

    const int64_t d_inner      = hparams.ssm_d_inner;
    const int64_t n_seqs       = ubatch.n_seqs;
    const int64_t head_k_dim   = hparams.ssm_d_state;
    const int64_t num_k_heads  = hparams.ssm_n_group;
    const int64_t num_v_heads  = hparams.ssm_dt_rank;
    const int64_t head_v_dim   = d_inner / num_v_heads;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    const auto kv_head = mctx_cur->get_head();

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    // Input projections
    auto qkvz = build_qkvz(cur, il);
    ggml_tensor * qkv_mixed = qkvz.first;
    ggml_tensor * z         = qkvz.second;

    ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
    cb(mixed_ba, "linear_attn_mixed_ba", il);

    // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
    int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);

    // Split mixed_ba into b and a (beta and alpha parameters)
    int64_t split_sizes_ba[2] = {
        num_v_heads / num_k_heads,  // beta size
        num_v_heads / num_k_heads   // alpha size
    };

    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
    cb(b, "b", il);

    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
                                   split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
    cb(a, "a", il);

    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);

    // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);

    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
    cb(alpha_softplus, "a_softplus", il);
    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
    cb(gate, "gate", il);

    // Get convolution states from cache
    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();

    // Build the convolution states tensor
    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    cb(conv_states, "conv_states", il);

    // Calculate convolution kernel size
    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
    const int64_t conv_kernel_size = conv_kernel->ne[0];
    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
    cb(conv_states, "conv_states_reshaped", il);

    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
    cb(qkv_mixed, "qkv_mixed_permuted", il);

    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
    cb(conv_input, "conv_input", il);

    // Update convolution state cache
    // Extract the last (conv_kernel_size - 1) states from conv_input
    ggml_tensor * last_conv_states =
        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
    cb(last_conv_states, "last_conv_states", il);

    ggml_tensor * state_update_target =
        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
    cb(state_update_target, "state_update_target", il);

    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
    cb(conv_states_all, "conv_states_updated", il);

    // Apply SSM convolution
    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
    cb(conv_output_proper, "conv_output_raw", il);

    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
    cb(conv_output_silu, "conv_output_silu", il);

    ggml_tensor * conv_qkv_mix = conv_output_silu;

    // Calculate the total conv dimension
    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);

    // Extract the convolved Q, K, V from conv_output
    ggml_tensor * q_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
    cb(q_conv, "q_conv", il);
    ggml_tensor * k_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
    cb(k_conv, "k_conv", il);
    ggml_tensor * v_conv =
        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
    cb(v_conv, "v_conv", il);

    // Unsqueeze them
    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
    cb(state, "state_predelta", il);

    // if the number of key heads and value heads differ, repeat to force the tensors into matching shapes
    if (num_k_heads != num_v_heads) {
        GGML_ASSERT(num_v_heads % num_k_heads == 0);
        int64_t repeat_factor = num_v_heads / num_k_heads;

        // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
        ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
        ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);

        // Repeat along the new singleton dimension (dim 1)
        ggml_tensor * q_repeated =
            ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
        ggml_tensor * k_repeated =
            ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);

        // Reshape back to merge the head and repeat dimensions
        // From [head_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1]
        // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
        q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
        k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
    }

    cb(q_conv, "q_conv_predelta", il);
    cb(k_conv, "k_conv_predelta", il);
    cb(v_conv, "v_conv_predelta", il);

    // Choose between build_delta_net_chunking and build_delta_net_autoregressive based on the number of tokens
    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
    if (n_seq_tokens == 1) {
        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
    } else {
        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
    }
    ggml_tensor * output    = attn_out.first;
    ggml_tensor * new_state = attn_out.second;
    cb(output, "attn_output", il);
    cb(new_state, "new_state", il);

    // Update the recurrent states
    ggml_build_forward_expand(gf,
                              ggml_cpy(ctx0, new_state,
                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

    // Reshape both the attention output and z to 2D tensors for normalization
    // output: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs]
    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

    // z: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs]
    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

    // Apply gated normalization: self.norm(core_attn_out, z)
    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);

    // Final reshape: [head_dim, n_heads * n_tokens * n_seqs] -> [n_heads * head_dim, n_tokens, n_seqs]
    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
    cb(final_output, "final_output", il);

    // Output projection
    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
    cb(cur, "linear_attn_out", il);

    // Reshape back to original dimensions
    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
    return cur;
}

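// Feed-forward layer: MoE (plus an optional sigmoid-gated shared expert) when
// ffn_gate_inp is present, otherwise a dense SwiGLU FFN.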
ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
    // Check if this is an MoE layer
    if (model.layers[il].ffn_gate_inp != nullptr) {
        // MoE branch
        ggml_tensor * moe_out =
            build_moe_ffn(cur,
                model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
                nullptr,
                n_expert, n_expert_used, LLM_FFN_SILU,
                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
        cb(moe_out, "ffn_moe_out", il);

        // Add shared experts if present - following Qwen3Next reference implementation
        if (model.layers[il].ffn_up_shexp != nullptr) {
            ggml_tensor * ffn_shexp =
                build_ffn(cur,
                    model.layers[il].ffn_up_shexp, NULL, NULL,
                    model.layers[il].ffn_gate_shexp, NULL, NULL,
                    model.layers[il].ffn_down_shexp, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(ffn_shexp, "ffn_shexp", il);

            // Apply shared expert gating as in the reference implementation
            // The shared expert has its own gate that is sigmoided
            // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
            ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
            cb(shared_gate, "shared_expert_gate", il);

            // Apply sigmoid to the gate
            shared_gate = ggml_sigmoid(ctx0, shared_gate);
            cb(shared_gate, "shared_expert_gate_sigmoid", il);

            // Apply the gate to the shared expert output
            ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
            cb(ffn_shexp, "ffn_shexp_gated", il);

            cur = ggml_add(ctx0, moe_out, ffn_shexp);
            cb(cur, "ffn_out", il);
        } else {
            cur = moe_out;
        }
    } else {
        // Dense FFN branch (not currently used, I believe)
        cur = build_ffn(cur,
            model.layers[il].ffn_up, NULL, NULL,
            model.layers[il].ffn_gate, NULL, NULL,
            model.layers[il].ffn_down, NULL, NULL,
            NULL,
            LLM_FFN_SILU, LLM_FFN_PAR, il);
        cb(cur, "ffn_out", il);
    }
    return cur;
}