diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/src/models/gemma3n-iswa.cpp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/src/models/gemma3n-iswa.cpp')
| -rw-r--r-- | llama.cpp/src/models/gemma3n-iswa.cpp | 384 |
1 file changed, 384 insertions, 0 deletions
diff --git a/llama.cpp/src/models/gemma3n-iswa.cpp b/llama.cpp/src/models/gemma3n-iswa.cpp new file mode 100644 index 0000000..7db6d3b --- /dev/null +++ b/llama.cpp/src/models/gemma3n-iswa.cpp | |||
| @@ -0,0 +1,384 @@ | |||
| 1 | #include "models.h" | ||
| 2 | |||
// Graph builder for Gemma 3n with interleaved sliding-window attention (iSWA).
//
// Beyond a standard decoder stack, this architecture adds (as visible below):
//   - AltUp: the hidden state is carried as n_altup parallel copies
//     ([n_embd, n_tokens, n_altup]); only the "active" copy (index i_altup_act)
//     goes through attention/FFN, the others are predicted/corrected around it
//   - Laurel: a low-rank residual branch computed alongside self-attention
//   - per-layer input embeddings gated into each layer after the FFN
//   - KV-cache reuse: layers without their own KV (hparams.has_kv(il) == false)
//     attend against an earlier layer's cache and only compute Q
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
    cb(inpL, "inp_scaled", -1);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

    // inpL now has only 1 altup, project it to the rest of the altups
    // these "added" altups will be concat to the last dim of inpL
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
        ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
        // rescale the projected copies so each keeps the magnitude of the original embedding
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }
    // inpL now has shape: [n_embd, n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

    for (int il = 0; il < n_layer; ++il) {
        // this block is made to closely resemble Gemma3p5DecoderLayer on python code
        // per-layer RoPE parameters: SWA and full-attention layers use different frequencies
        const float freq_base_l = model.get_rope_freq_base(cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
        cur = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

        // self-attention
        if (hparams.has_kv(il)) {
            // this layer owns a KV cache: compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // note: V is RMS-normed as well (no learned weight), unlike most models
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                    NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                    hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers: only Q is computed, K/V are nullptr
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                    model.layers[il].wo, NULL,
                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // residual connection back onto the active prediction
        cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

        // average the attention output with the laurel branch (1/sqrt(2) keeps variance)
        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network (gated GELU, with optional activation sparsity)
        {
            ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity (zero out sub-cutoff activations)
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        // FFN residual on top of the attention+laurel mix
        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        // propagate the activated output back into all altup copies
        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

        // gate the per-layer input embedding with the (scaled) active correction
        ggml_tensor * first_prediction; // [n_embd, n_tokens]
        {
            first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }
        // equivalent to python code: corrected_predictions[1:] += first_prediction
        // (slice 0 is left untouched; the remaining n_altup-1 slices each get the add)
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            ggml_tensor * slice_rest = ggml_view_3d(
                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
            corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
        }
        cur = corrected; // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }
    cur = inpL; // [n_embd, n_tokens, n_altup]

    // cur now has multiple altup(s), we want to merge them back to 1 altup
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
        // do a view to skip the first slice (active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
        // rescale the un-embedded copies to the magnitude of the active altup
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
        cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }
    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping: softcap * tanh(logits / softcap)
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
| 233 | |||
| 234 | ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { | ||
| 235 | return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); | ||
| 236 | } | ||
| 237 | |||
| 238 | // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim | ||
| 239 | ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { | ||
| 240 | GGML_ASSERT(idx < (int) x->ne[2]); | ||
| 241 | return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), | ||
| 242 | idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); | ||
| 243 | } | ||
| 244 | |||
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
//
// Fetches the per-layer input embeddings for the current ubatch. For token
// input, rows of model.tok_embd_per_layer are gathered by token id and scaled
// by sqrt(n_embd_altup). For non-token (vision embedding) input, the padding
// token's row (ID=0) is used instead.
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
    ggml_tensor * inp_per_layer;
    if (ubatch.token) {
        // declare the token-id tensor as a graph input and register it on the result
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_inp_tokens = inp->tokens;
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
        cb(inp_per_layer, "inp_per_layer_selected", -1);
        res->add_input(std::move(inp));
    } else {
        // Vision embedding path: use padding token (ID=0) embedding
        // TODO: verify if this is the correct behavior in transformers implementation
        const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
        // NOTE: in this branch `inp` is never registered as a graph input

        // Extract and dequantize padding token embedding (row 0)
        ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
        inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);

        // Reshape to [n_embd_altup, n_layer, 1]
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
        cb(inp_per_layer, "inp_per_layer_vision", -1);
    }
    return inp_per_layer;
}
| 274 | |||
| 275 | // equivalent to project_per_layer_inputs() in python code | ||
| 276 | // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim | ||
| 277 | // output shape: [n_embd_altup, n_tokens, n_layer] | ||
| 278 | ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { | ||
| 279 | const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); | ||
| 280 | const float per_layer_input_scale = 1.0f / sqrtf(2.0f); | ||
| 281 | |||
| 282 | ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); | ||
| 283 | per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); | ||
| 284 | per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); | ||
| 285 | per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, | ||
| 286 | -1); // [n_embd_altup, n_layer, n_tokens] | ||
| 287 | cb(per_layer_proj, "per_layer_proj", -1); | ||
| 288 | |||
| 289 | inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); | ||
| 290 | inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); | ||
| 291 | cb(inp_per_layer, "inp_per_layer", -1); | ||
| 292 | |||
| 293 | // permute to shape: [n_embd_altup, n_tokens, n_layer] | ||
| 294 | inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); | ||
| 295 | return inp_per_layer; | ||
| 296 | } | ||
| 297 | |||
| 298 | // input cur shape: [n_altup, n_tokens] | ||
| 299 | // output shape: [n_altup, n_tokens] | ||
| 300 | ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { | ||
| 301 | ggml_tensor * tmp = cur; | ||
| 302 | tmp = build_lora_mm(model.layers[il].laurel_l, tmp); | ||
| 303 | tmp = build_lora_mm(model.layers[il].laurel_r, tmp); | ||
| 304 | tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); | ||
| 305 | tmp = ggml_add(ctx0, tmp, cur); | ||
| 306 | cb(tmp, "laurel_out", il); | ||
| 307 | return tmp; | ||
| 308 | } | ||
| 309 | |||
| 310 | // input x shape: [n_embd, n_tokens] | ||
| 311 | // output shape: [n_embd, n_tokens] | ||
| 312 | ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { | ||
| 313 | ggml_tensor * mean = ggml_mean(ctx0, x); | ||
| 314 | ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), | ||
| 315 | 1.0f / (float) (x->ne[0] - 1))); | ||
| 316 | ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul)); | ||
| 317 | return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x)); | ||
| 318 | } | ||
| 319 | |||
| 320 | // | ||
| 321 | // altup functions | ||
| 322 | // | ||
| 323 | |||
| 324 | // equivalent to compute_router_modalities() in python code | ||
| 325 | // input x shape: [n_embd, n_tokens] | ||
| 326 | // output shape: [n_altup, n_tokens] | ||
| 327 | ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) { | ||
| 328 | ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); | ||
| 329 | |||
| 330 | // router_input_scale | ||
| 331 | router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd); | ||
| 332 | |||
| 333 | ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); | ||
| 334 | return ggml_tanh(ctx0, output); // [n_altup, n_tokens] | ||
| 335 | } | ||
| 336 | |||
// AltUp prediction step: routes the active altup copy through the router to
// obtain an [n_altup x n_altup] mixing matrix per token, mixes all copies with
// it, and adds the result back onto the originals (residual).
// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
    ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    // per-token mixing coefficients from the router output
    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
    cb(all_coefs, "all_coefs", il);
    // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);

    // permute to [n_altup, n_embd, n_tokens]
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    // batched matmul: mix the n_altup copies per (embd, token) element
    ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]

    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
    predictions = ggml_add(ctx0, predictions, cur); // residual
    cb(predictions, "predictions", il);

    return predictions;
}
| 360 | |||
// AltUp correction step: computes the "innovation" (difference between the
// activated output and the active prediction), scales it per altup copy by
// router-derived coefficients (+1), and adds it back onto all predictions.
// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
    ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
    cb(innovation, "innovation", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
    all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
    cb(all_coefs, "all_coefs", il);
    // reshape coefs to [1, n_tokens, n_altup] so they broadcast over n_embd below
    all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup]
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]

    // replicate the innovation across all altup copies, then scale each copy
    innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
    corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
    cb(corrected, "corrected", il);

    return corrected;
}
