1#include "models.h"
  2
  3llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
  4    llm_graph_context(params),
  5    model(model),
  6    n_embd_head(model.hparams.n_embd_head_k),
  7    n_embd_altup(model.hparams.n_embd_altup),
  8    n_altup(model.hparams.n_altup),
  9    i_altup_act(model.hparams.i_altup_act) {
 10    ggml_tensor * cur;
 11    ggml_tensor * inpL;
 12
 13    inpL = build_inp_embd(model.tok_embd);
 14
 15    // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
 16    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
 17    cb(inpL, "inp_scaled", -1);
 18
 19    // inp_pos - contains the positions
 20    ggml_tensor * inp_pos = build_inp_pos();
 21
 22    // TODO: is causal == true correct? might need some changes
 23    auto * inp_attn = build_attn_inp_kv_iswa();
 24
 25    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
 26    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
 27
 28    // inpL now has only 1 altup, project it to the rest of the altups
 29    // these "added" altups will be concat to the last dim of inpL
 30    {
 31        ggml_tensor * target_magnitude = calc_magnitude(inpL);
 32        ggml_tensor * inp_repeated     = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
 33        ggml_tensor * altup_added =
 34            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated);  // shape: [n_embd, n_tokens, n_altup - 1]
 35        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
 36        altup_added                 = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
 37        inpL                        = ggml_concat(ctx0, inpL, altup_added, 2);  // shape: [n_embd, n_tokens, n_altup]
 38        cb(inpL, "inp_stacked", -1);
 39    }
 40    // inpL now has shape:          [n_embd,       n_tokens, n_altup]
 41    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
 42
 43    for (int il = 0; il < n_layer; ++il) {
 44        // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
 45        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
 46        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 47
 48        ggml_tensor * cur         = inpL;                    // [n_embd, n_tokens, n_altup]
 49        ggml_tensor * predictions = altup_predict(cur, il);  // [n_embd, n_tokens, n_altup]
 50
 51        // predicted value will go through self-attention and laurel
 52        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);  // [n_embd, n_tokens]
 53        cur                             = active_prediction;
 54        cb(cur, "active_prediction", il);
 55
 56        // norm
 57        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
 58        cb(cur, "attn_norm", il);
 59
 60        // laurel
 61        ggml_tensor * laurel_out = laurel(cur, il);  // [n_embd, n_tokens]
 62
 63        // self-attention
 64        if (hparams.has_kv(il)) {
 65            // compute Q and K and RoPE them
 66            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 67            cb(Qcur, "Qcur", il);
 68
 69            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
 70            cb(Kcur, "Kcur", il);
 71
 72            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
 73            cb(Vcur, "Vcur", il);
 74
 75            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 76            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 77            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 78
 79            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
 80            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
 81            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
 82
 83            cb(Qcur, "Qcur_normed", il);
 84            cb(Kcur, "Kcur_normed", il);
 85            cb(Vcur, "Vcur_normed", il);
 86
 87            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
 88                                 ext_factor, attn_factor, beta_fast, beta_slow);
 89
 90            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
 91                                 ext_factor, attn_factor, beta_fast, beta_slow);
 92
 93            cb(Qcur, "Qcur_pos", il);
 94            cb(Kcur, "Kcur_pos", il);
 95
 96            cur = build_attn(inp_attn, model.layers[il].wo,
 97                    NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
 98                    hparams.f_attention_scale, il);
 99        } else {
100            // reuse KV cache of earlier layers
101            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
102            cb(Qcur, "Qcur", il);
103            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
104
105            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
106            cb(Qcur, "Qcur_normed", il);
107
108            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
109                                 ext_factor, attn_factor, beta_fast, beta_slow);
110            cb(Qcur, "Qcur_pos", il);
111
112            cur = build_attn(inp_attn,
113                    model.layers[il].wo, NULL,
114                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
115        }
116        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
117        cb(cur, "attn_post_norm", il);
118
119        cur = ggml_add(ctx0, cur, active_prediction);  // [n_embd, n_tokens]
120        cb(cur, "attn_gated", il);
121
122        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
123                                               1.0f / sqrtf(2.0f));  // [n_embd, n_tokens]
124        cb(attn_laurel, "attn_laurel", il);
125
126        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
127        cb(cur, "ffn_norm", il);
128
129        // feed-forward network
130        {
131            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up, cur);
132            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
133
134            if (il < n_layer_sparsity) {
135                // apply activation sparsity
136                gate_proj = gaussian_topk(gate_proj);
137            }
138            gate_proj = ggml_gelu(ctx0, gate_proj);
139
140            cur = ggml_mul(ctx0, up_proj, gate_proj);
141            cur = build_lora_mm(model.layers[il].ffn_down, cur);
142            cb(cur, "ffn_out", il);
143        }
144        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
145        cb(cur, "ffn_post_norm", il);
146
147        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel);  // [n_embd, n_tokens]
148        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
149
150        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il);  // [n_embd, n_tokens, n_altup]
151
152        ggml_tensor * first_prediction;                                                   // [n_embd, n_tokens]
153        {
154            first_prediction = view_2d_slice(corrected, i_altup_act);                     // [n_embd, n_tokens]
155            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
156            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
157            first_prediction = ggml_gelu(ctx0, first_prediction);                 // [n_embd_altup, n_tokens]
158            cb(first_prediction, "first_prediction_gated", il);
159            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il);      // [n_embd_altup, n_tokens]
160            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer);  // [n_embd_altup, n_tokens]
161            cb(first_prediction, "first_prediction_scaled", il);
162
163            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction);  // [n_embd, n_tokens]
164            first_prediction =
165                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
166            cb(first_prediction, "first_prediction_out", il);
167        }
168        // equivalent to python code: corrected_predictions[1:] += first_prediction
169        {
170            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
171            ggml_tensor * slice_rest  = ggml_view_3d(
172                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
173                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
174            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction);  // [n_embd, n_tokens, n_altup - 1]
175            corrected         = ggml_concat(ctx0, slice_first, tmp, 2);        // [n_embd, n_tokens, n_altup]
176        }
177        cur = corrected;                                                       // [n_embd, n_tokens, n_altup]
178        cur = build_cvec(cur, il);
179        cb(cur, "l_out", il);
180
181        // input for next layer
182        inpL = cur;
183    }
184    cur = inpL;  // [n_embd, n_tokens, n_altup]
185
186    // cur now has multiple altup(s), we want to merge them back to 1 altup
187    {
188        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act));  // [n_embd, n_tokens]
189        // do a view to skip the first slice (active altup)
190        ggml_tensor * alt_slice =
191            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
192                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
193        ggml_tensor * altup_unembd =
194            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice);  // shape: [n_embd, n_tokens, n_altup - 1]
195        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
196        altup_unembd                = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
197        cb(altup_unembd, "altup_unembd", -1);
198
199        // equivalent to torch.mean(hidden_states, dim=0)
200        cur = view_2d_slice(cur, 0);  // [n_embd, n_tokens]
201        for (int i = 0; i < n_altup - 1; ++i) {
202            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
203        }
204        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup));  // [n_embd, n_tokens]
205        cb(cur, "unembd_merged", -1);
206    }
207    // cur now has shape: [n_embd, n_tokens]
208
209    // TODO: move this to right after the last KV layer
210    {
211        // skip computing output for unused tokens
212        ggml_tensor * inp_out_ids = build_inp_out_ids();
213        cur                       = ggml_get_rows(ctx0, cur, inp_out_ids);
214    }
215    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
216
217    cb(cur, "result_norm", -1);
218    res->t_embd = cur;
219
220    cur = build_lora_mm(model.output, cur);
221
222    {
223        // final logit soft-capping
224        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
225        cur = ggml_tanh(ctx0, cur);
226        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
227    }
228    cb(cur, "result_output", -1);
229    res->t_logits = cur;
230
231    ggml_build_forward_expand(gf, cur);
232}
233
234ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
235    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
236}
237
238// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
239ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
240    GGML_ASSERT(idx < (int) x->ne[2]);
241    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
242                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
243}
244
245// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens] for token input; [n_embd_altup, n_layer, 1] for raw-embedding input
247ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
248    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
249    ggml_tensor * inp_per_layer;
250    if (ubatch.token) {
251        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
252        ggml_set_input(inp->tokens);
253        res->t_inp_tokens = inp->tokens;
254        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
255        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
256        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
257        cb(inp_per_layer, "inp_per_layer_selected", -1);
258        res->add_input(std::move(inp));
259    } else {
260        // Vision embedding path: use padding token (ID=0) embedding
261        // TODO: verify if this is the correct behavior in transformers implementation
262        const int64_t embd_size = model.tok_embd_per_layer->ne[0];  // n_embd_altup * n_layer
263
264        // Extract and dequantize padding token embedding (row 0)
265        ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
266        inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
267
268        // Reshape to [n_embd_altup, n_layer, 1]
269        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
270        cb(inp_per_layer, "inp_per_layer_vision", -1);
271    }
272    return inp_per_layer;
273}
274
275// equivalent to project_per_layer_inputs() in python code
276// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
277// output shape: [n_embd_altup, n_tokens, n_layer]
278ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
279    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
280    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
281
282    ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
283    per_layer_proj               = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
284    per_layer_proj               = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
285    per_layer_proj               = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
286                                              -1);  // [n_embd_altup, n_layer, n_tokens]
287    cb(per_layer_proj, "per_layer_proj", -1);
288
289    inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
290    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
291    cb(inp_per_layer, "inp_per_layer", -1);
292
293    // permute to shape: [n_embd_altup, n_tokens, n_layer]
294    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
295    return inp_per_layer;
296}
297
// input cur shape: [n_embd, n_tokens]
// output    shape: [n_embd, n_tokens]
300ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
301    ggml_tensor * tmp = cur;
302    tmp               = build_lora_mm(model.layers[il].laurel_l, tmp);
303    tmp               = build_lora_mm(model.layers[il].laurel_r, tmp);
304    tmp               = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
305    tmp               = ggml_add(ctx0, tmp, cur);
306    cb(tmp, "laurel_out", il);
307    return tmp;
308}
309
// input x shape: [n_ff, n_tokens] (the FFN gate projection); statistics are taken per token over dim 0
// output  shape: same as x
312ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
313    ggml_tensor * mean = ggml_mean(ctx0, x);
314    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
315                                                    1.0f / (float) (x->ne[0] - 1)));
316    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
317    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
318}
319
320//
321// altup functions
322//
323
324// equivalent to compute_router_modalities() in python code
325// input x shape: [n_embd,  n_tokens]
326// output  shape: [n_altup, n_tokens]
327ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
328    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);
329
330    // router_input_scale
331    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);
332
333    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
334    return ggml_tanh(ctx0, output);  // [n_altup, n_tokens]
335}
336
337// input cur shape: [n_embd, n_tokens, n_altup]
338// output    shape: [n_embd, n_tokens, n_altup]
339ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
340    ggml_tensor * activated  = view_2d_slice(cur, i_altup_act);                 // [n_embd, n_tokens]
341    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
342    cb(modalities, "modalities", il);
343
344    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
345    cb(all_coefs, "all_coefs", il);
346    // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
347    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
348
349    // permute to [n_altup, n_embd, n_tokens]
350    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
351    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs);  // [n_altup, n_embd, n_tokens]
352
353    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
354    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
355    predictions = ggml_add(ctx0, predictions, cur);
356    cb(predictions, "predictions", il);
357
358    return predictions;
359}
360
361// input predictions       shape: [n_embd, n_tokens, n_altup]
362// input activated         shape: [n_embd, n_tokens]
363// output                  shape: [n_embd, n_tokens, n_altup]
364ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
365    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
366    cb(modalities, "modalities", il);
367
368    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
369    ggml_tensor * innovation        = ggml_sub(ctx0, activated, active_prediction);  // [n_embd, n_tokens]
370    cb(innovation, "innovation", il);
371
372    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities);  // [n_altup, n_tokens]
373    all_coefs               = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f);                    // + 1.0
374    cb(all_coefs, "all_coefs", il);
375    all_coefs = ggml_transpose(ctx0, all_coefs);                                               // [n_tokens, n_altup]
376    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup);                           // [1, n_tokens, n_altup]
377
378    innovation              = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
379    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs);   // [n_embd, n_tokens, n_altup]
380    corrected               = ggml_add(ctx0, corrected, predictions);  // [n_embd, n_tokens, n_altup]
381    cb(corrected, "corrected", il);
382
383    return corrected;
384}