path: root/llama.cpp/src/models/gemma3n-iswa.cpp
author     Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer  Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit     b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree       211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/src/models/gemma3n-iswa.cpp
Engage!
Diffstat (limited to 'llama.cpp/src/models/gemma3n-iswa.cpp')
-rw-r--r--  llama.cpp/src/models/gemma3n-iswa.cpp  384
1 file changed, 384 insertions, 0 deletions
diff --git a/llama.cpp/src/models/gemma3n-iswa.cpp b/llama.cpp/src/models/gemma3n-iswa.cpp
new file mode 100644
index 0000000..7db6d3b
--- /dev/null
+++ b/llama.cpp/src/models/gemma3n-iswa.cpp
@@ -0,0 +1,384 @@
#include "models.h"

llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
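    // token embeddings are scaled by sqrt(n_embd) (the usual Gemma embedding scaling);
    // raw (e.g. vision) embeddings pass through unscaled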
    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
    cb(inpL, "inp_scaled", -1);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();
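    // iswa = interleaved sliding-window attention: most layers attend within a local
    // window, while the remaining layers attend over the full context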

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

    // inpL now has only 1 altup, project it to the rest of the altups
    // these "added" altups will be concat to the last dim of inpL
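    // each added altup is a learned projection of the input, rescaled so its per-token
    // L2 magnitude matches that of the original stream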
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
        ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }
    // inpL now has shape: [n_embd, n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

    for (int il = 0; il < n_layer; ++il) {
        // this block closely resembles Gemma3p5DecoderLayer in the python code
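        // sliding-window layers use a different RoPE base/scale than full-attention
        // layers, hence the per-layer lookup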
        const float freq_base_l = model.get_rope_freq_base(cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
        cur = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

        // self-attention
        if (hparams.has_kv(il)) {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                    NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                    hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers
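            // this layer has no KV of its own (hparams.has_kv(il) == false): only Q is
            // computed, and build_attn attends over the K/V stored by an earlier layer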
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                    model.layers[il].wo, NULL,
                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

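        // average the attention and laurel branches; the 1/sqrt(2) factor keeps the
        // magnitude of the sum comparable to that of a single branch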
        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network
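        // gated FFN (gelu), with activation sparsity applied to the gate on the first
        // n_layer_sparsity layers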
        {
            ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

        ggml_tensor * first_prediction; // [n_embd, n_tokens]
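        // project the corrected active stream to n_embd_altup and gelu-gate it, multiply
        // with this layer's per-layer input, then project back to n_embd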
        {
            first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }
        // equivalent to python code: corrected_predictions[1:] += first_prediction
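        // ggml cannot add in-place into a slice, so take slice 0 unchanged, add
        // first_prediction (broadcast along the 3rd dim) to slices 1..n_altup-1, and re-concat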
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            ggml_tensor * slice_rest = ggml_view_3d(
                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
            corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
        }
        cur = corrected; // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }
    cur = inpL; // [n_embd, n_tokens, n_altup]

    // cur now has multiple altup(s), we want to merge them back to 1 altup
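    // project the n_altup - 1 extra streams back with altup_unembd_proj
    // (magnitude-matched to the active stream), then average all streams into one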
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
        // do a view to skip the first slice (active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
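        // sum the active stream (slice 0) with the n_altup - 1 unembedded streams,
        // then divide by n_altup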
        cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }
    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping
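        // logits = C * tanh(logits / C), which bounds them to (-C, C)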
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

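// per-token L2 norm along the first dim (reduces it to size 1),
// e.g. [n_embd, n_tokens, ...] -> [1, n_tokens, ...]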
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
    GGML_ASSERT(idx < (int) x->ne[2]);
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
    ggml_tensor * inp_per_layer;
    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_inp_tokens = inp->tokens;
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
        cb(inp_per_layer, "inp_per_layer_selected", -1);
        res->add_input(std::move(inp));
    } else {
        // Vision embedding path: use padding token (ID=0) embedding
        // TODO: verify if this is the correct behavior in transformers implementation
        const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer

        // Extract and dequantize padding token embedding (row 0)
        ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
        inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);

        // Reshape to [n_embd_altup, n_layer, 1]
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
        cb(inp_per_layer, "inp_per_layer_vision", -1);
    }
    return inp_per_layer;
}

// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
    const float per_layer_input_scale = 1.0f / sqrtf(2.0f);

    ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
    per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
    per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
    per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
                                -1); // [n_embd_altup, n_layer, n_tokens]
    cb(per_layer_proj, "per_layer_proj", -1);

    inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
    cb(inp_per_layer, "inp_per_layer", -1);

    // permute to shape: [n_embd_altup, n_tokens, n_layer]
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
    return inp_per_layer;
}

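// laurel = learned augmented residual layer: a low-rank residual branch (laurel_l
// followed by laurel_r) with an RMS post-norm, added back onto the input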
// input cur shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
    ggml_tensor * tmp = cur;
    tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
    tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
    tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
    tmp = ggml_add(ctx0, tmp, cur);
    cb(tmp, "laurel_out", il);
    return tmp;
}

// input x shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
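// keeps only activations above mean + f_sparsity_std_mul * std (a per-row Gaussian
// quantile cutoff), implemented as relu(x - cutoff); everything below the cutoff becomes 0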
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
    ggml_tensor * mean = ggml_mean(ctx0, x);
    ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
                                                   1.0f / (float) (x->ne[0] - 1)));
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

//
// altup functions
//

// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape: [n_altup, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);

    // router_input_scale
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);

    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
    return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}

// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
    ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
    cb(all_coefs, "all_coefs", il);
    // the first dim now has n_altup^2 elements; reshape it to [n_altup, n_altup] per token (a 3D tensor overall)
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);

    // permute to [n_altup, n_embd, n_tokens]
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]
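    // per token this mixes the altup streams: prediction_j = sum_i all_coefs[i][j] * stream_i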

    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
    predictions = ggml_add(ctx0, predictions, cur);
    cb(predictions, "predictions", il);

    return predictions;
}

// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
    ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
    cb(innovation, "innovation", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
    all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
    cb(all_coefs, "all_coefs", il);
    all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup]
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
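    // scale the shared innovation by a per-stream, per-token coefficient (coef + 1),
    // then add it back onto each predicted stream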

    innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
    corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
    cb(corrected, "corrected", il);

    return corrected;
}