diff options
Diffstat (limited to 'llama.cpp/src/models/gemma3.cpp')
| -rw-r--r-- | llama.cpp/src/models/gemma3.cpp | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/llama.cpp/src/models/gemma3.cpp b/llama.cpp/src/models/gemma3.cpp new file mode 100644 index 0000000..dec3fc4 --- /dev/null +++ b/llama.cpp/src/models/gemma3.cpp | |||
| @@ -0,0 +1,155 @@ | |||
| 1 | #include "models.h" | ||
| 2 | |||
| 3 | template <bool iswa> | ||
| 4 | llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { | ||
| 5 | const int64_t n_embd_head = hparams.n_embd_head_k; | ||
| 6 | |||
| 7 | ggml_tensor * cur; | ||
| 8 | ggml_tensor * inpL; | ||
| 9 | |||
| 10 | inpL = build_inp_embd(model.tok_embd); | ||
| 11 | |||
| 12 | // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) | ||
| 13 | inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); | ||
| 14 | cb(inpL, "inp_scaled", -1); | ||
| 15 | |||
| 16 | // inp_pos - contains the positions | ||
| 17 | ggml_tensor * inp_pos = build_inp_pos(); | ||
| 18 | |||
| 19 | // TODO: is causal == true correct? might need some changes | ||
| 20 | using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; | ||
| 21 | inp_attn_type * inp_attn = nullptr; | ||
| 22 | |||
| 23 | if constexpr (iswa) { | ||
| 24 | inp_attn = build_attn_inp_kv_iswa(); | ||
| 25 | } else { | ||
| 26 | inp_attn = build_attn_inp_kv(); | ||
| 27 | } | ||
| 28 | |||
| 29 | ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||
| 30 | |||
| 31 | for (int il = 0; il < n_layer; ++il) { | ||
| 32 | float freq_base_l = 0.0f; | ||
| 33 | float freq_scale_l = 0.0f; | ||
| 34 | |||
| 35 | if constexpr (iswa) { | ||
| 36 | freq_base_l = model.get_rope_freq_base (cparams, il); | ||
| 37 | freq_scale_l = model.get_rope_freq_scale(cparams, il); | ||
| 38 | } else { | ||
| 39 | freq_base_l = freq_base; | ||
| 40 | freq_scale_l = freq_scale; | ||
| 41 | } | ||
| 42 | |||
| 43 | // norm | ||
| 44 | cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); | ||
| 45 | cb(cur, "attn_norm", il); | ||
| 46 | |||
| 47 | // self-attention | ||
| 48 | { | ||
| 49 | // compute Q and K and RoPE them | ||
| 50 | ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); | ||
| 51 | cb(Qcur, "Qcur", il); | ||
| 52 | |||
| 53 | ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); | ||
| 54 | cb(Kcur, "Kcur", il); | ||
| 55 | |||
| 56 | ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); | ||
| 57 | cb(Vcur, "Vcur", il); | ||
| 58 | |||
| 59 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); | ||
| 60 | Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); | ||
| 61 | Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); | ||
| 62 | |||
| 63 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); | ||
| 64 | cb(Qcur, "Qcur_normed", il); | ||
| 65 | |||
| 66 | Qcur = ggml_rope_ext( | ||
| 67 | ctx0, Qcur, inp_pos, nullptr, | ||
| 68 | n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, | ||
| 69 | ext_factor, attn_factor, beta_fast, beta_slow); | ||
| 70 | |||
| 71 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); | ||
| 72 | cb(Kcur, "Kcur_normed", il); | ||
| 73 | |||
| 74 | Kcur = ggml_rope_ext( | ||
| 75 | ctx0, Kcur, inp_pos, nullptr, | ||
| 76 | n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, | ||
| 77 | ext_factor, attn_factor, beta_fast, beta_slow); | ||
| 78 | |||
| 79 | cb(Qcur, "Qcur", il); | ||
| 80 | cb(Kcur, "Kcur", il); | ||
| 81 | cb(Vcur, "Vcur", il); | ||
| 82 | |||
| 83 | // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 | ||
| 84 | Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); | ||
| 85 | |||
| 86 | cur = build_attn(inp_attn, | ||
| 87 | model.layers[il].wo, NULL, | ||
| 88 | Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); | ||
| 89 | } | ||
| 90 | if (il == n_layer - 1 && inp_out_ids) { | ||
| 91 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); | ||
| 92 | inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); | ||
| 93 | } | ||
| 94 | cur = build_norm(cur, | ||
| 95 | model.layers[il].attn_post_norm, NULL, | ||
| 96 | LLM_NORM_RMS, il); | ||
| 97 | cb(cur, "attn_post_norm", il); | ||
| 98 | |||
| 99 | ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); | ||
| 100 | cb(sa_out, "sa_out", il); | ||
| 101 | |||
| 102 | cur = build_norm(sa_out, | ||
| 103 | model.layers[il].ffn_norm, NULL, | ||
| 104 | LLM_NORM_RMS, il); | ||
| 105 | cb(cur, "ffn_norm", il); | ||
| 106 | |||
| 107 | // feed-forward network | ||
| 108 | { | ||
| 109 | cur = build_ffn(cur, | ||
| 110 | model.layers[il].ffn_up, NULL, NULL, | ||
| 111 | model.layers[il].ffn_gate, NULL, NULL, | ||
| 112 | model.layers[il].ffn_down, NULL, NULL, | ||
| 113 | NULL, | ||
| 114 | LLM_FFN_GELU, LLM_FFN_PAR, il); | ||
| 115 | cb(cur, "ffn_out", il); | ||
| 116 | } | ||
| 117 | cur = build_norm(cur, | ||
| 118 | model.layers[il].ffn_post_norm, NULL, | ||
| 119 | LLM_NORM_RMS, -1); | ||
| 120 | cb(cur, "ffn_post_norm", il); | ||
| 121 | |||
| 122 | cur = ggml_add(ctx0, cur, sa_out); | ||
| 123 | |||
| 124 | cur = build_cvec(cur, il); | ||
| 125 | cb(cur, "l_out", il); | ||
| 126 | |||
| 127 | // input for next layer | ||
| 128 | inpL = cur; | ||
| 129 | } | ||
| 130 | cur = inpL; | ||
| 131 | |||
| 132 | cur = build_norm(cur, | ||
| 133 | model.output_norm, NULL, | ||
| 134 | LLM_NORM_RMS, -1); | ||
| 135 | |||
| 136 | cb(cur, "result_norm", -1); | ||
| 137 | res->t_embd = cur; | ||
| 138 | |||
| 139 | // lm_head | ||
| 140 | cur = build_lora_mm(model.output, cur); | ||
| 141 | |||
| 142 | if (hparams.f_final_logit_softcapping) { | ||
| 143 | cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); | ||
| 144 | cur = ggml_tanh(ctx0, cur); | ||
| 145 | cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); | ||
| 146 | } | ||
| 147 | |||
| 148 | cb(cur, "result_output", -1); | ||
| 149 | res->t_logits = cur; | ||
| 150 | |||
| 151 | ggml_build_forward_expand(gf, cur); | ||
| 152 | } | ||
| 153 | |||
| 154 | template struct llm_build_gemma3<false>; | ||
| 155 | template struct llm_build_gemma3<true>; | ||
