1#include "models.h"
  2
  3llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
  4    const int64_t n_embd_head = hparams.n_embd_head_v;
  5    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
  6
  7    ggml_tensor * cur;
  8    ggml_tensor * inpL;
  9
 10    inpL = build_inp_embd(model.tok_embd);
 11
 12    // MuP scaling: embeddings * sqrt(hidden_size)
 13    // mup_enabled = true, hidden_size = 1024, scale = 32.0
 14    inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
 15    cb(inpL, "inp_embd_scaled", -1);
 16
 17    // inp_pos - contains the positions
 18    ggml_tensor * inp_pos = build_inp_pos();
 19    auto * inp_attn = build_attn_inp_kv_iswa();
 20    ggml_tensor * inp_out_ids = build_inp_out_ids();
 21
 22    const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
 23
 24    for (int il = 0; il < n_layer; ++il) {
 25        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
 26        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 27
 28        ggml_tensor * inpSA = inpL;
 29
 30        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
 31        const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
 32                              (il + 1) % hparams.n_no_rope_layer_step != 0;
 33
 34        // dual attention normalization (pre)
 35        cur = build_norm(inpL,
 36                model.layers[il].attn_norm, NULL,
 37                LLM_NORM_RMS, il);
 38        cb(cur, "attn_norm", il);
 39
 40        // self-attention
 41        {
 42            ggml_tensor * attn_inp = cur;  // save input for gate computation
 43
 44            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 45            cb(Qcur, "Qcur", il);
 46
 47            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
 48            cb(Kcur, "Kcur", il);
 49
 50            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
 51            cb(Vcur, "Vcur", il);
 52
 53            // compute gate from input
 54            ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
 55            cb(gate, "attn_gate_proj", il);
 56
 57            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
 58            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 59
 60            // Q/K normalization
 61            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
 62            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
 63            cb(Qcur, "Qcur_normed", il);
 64            cb(Kcur, "Kcur_normed", il);
 65
 66            if (use_rope) {
 67                Qcur = ggml_rope_ext(
 68                        ctx0, Qcur, inp_pos, nullptr,
 69                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
 70                        ext_factor, attn_factor, beta_fast, beta_slow);
 71                cb(Qcur, "Qcur_rope", il);
 72
 73                Kcur = ggml_rope_ext(
 74                        ctx0, Kcur, inp_pos, nullptr,
 75                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
 76                        ext_factor, attn_factor, beta_fast, beta_slow);
 77                cb(Kcur, "Kcur_rope", il);
 78            }
 79
 80            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 81
 82            cur = build_attn(inp_attn,
 83                    NULL, NULL,  // wo will be applied after gating
 84                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
 85            cb(cur, "attn_out", il);
 86
 87            // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
 88            gate = ggml_sigmoid(ctx0, gate);
 89            cb(gate, "attn_gate_sig", il);
 90            cur = ggml_mul(ctx0, cur, gate);
 91            cb(cur, "attn_gated", il);
 92
 93            // now apply output projection
 94            cur = build_lora_mm(model.layers[il].wo, cur);
 95            cb(cur, "attn_o_proj", il);
 96        }
 97
 98        // dual attention normalization (post)
 99        cur = build_norm(cur,
100                model.layers[il].attn_post_norm, NULL,
101                LLM_NORM_RMS, il);
102        cb(cur, "attn_post_norm", il);
103
104        if (il == n_layer - 1 && inp_out_ids) {
105            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
106            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
107        }
108
109        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
110        cb(ffn_inp, "ffn_inp", il);
111
112        // dual ffn normalization (pre)
113        cur = build_norm(ffn_inp,
114                model.layers[il].ffn_norm, NULL,
115                LLM_NORM_RMS, il);
116        cb(cur, "ffn_norm", il);
117
118        // MoE or dense FFN
119        if ((uint32_t)il >= hparams.n_layer_dense_lead) {
120            // MoE layer with sigmoid routing, normalization, and scaling
121            ggml_tensor * moe_out = build_moe_ffn(cur,
122                    model.layers[il].ffn_gate_inp,
123                    model.layers[il].ffn_up_exps,
124                    model.layers[il].ffn_gate_exps,
125                    model.layers[il].ffn_down_exps,
126                    model.layers[il].ffn_exp_probs_b,
127                    n_expert, n_expert_used,
128                    LLM_FFN_SILU,
129                    hparams.expert_weights_norm,           // norm_w (route_norm=True)
130                    hparams.expert_weights_scale,          // scale_w
131                    hparams.expert_weights_scale,          // w_scale (route_scale=2.826)
132                    (llama_expert_gating_func_type) hparams.expert_gating_func,
133                    il);
134            cb(moe_out, "ffn_moe_out", il);
135
136            // shared expert
137            if (hparams.n_expert_shared > 0) {
138                ggml_tensor * ffn_shexp = build_ffn(cur,
139                        model.layers[il].ffn_up_shexp,   NULL, NULL,
140                        model.layers[il].ffn_gate_shexp, NULL, NULL,
141                        model.layers[il].ffn_down_shexp, NULL, NULL,
142                        NULL,
143                        LLM_FFN_SILU, LLM_FFN_PAR, il);
144                cb(ffn_shexp, "ffn_shexp", il);
145
146                cur = ggml_add(ctx0, moe_out, ffn_shexp);
147                cb(cur, "ffn_out", il);
148            } else {
149                cur = moe_out;
150            }
151        } else {
152            // dense layer
153            cur = build_ffn(cur,
154                    model.layers[il].ffn_up,   NULL, NULL,
155                    model.layers[il].ffn_gate, NULL, NULL,
156                    model.layers[il].ffn_down, NULL, NULL,
157                    NULL,
158                    LLM_FFN_SILU, LLM_FFN_PAR, il);
159            cb(cur, "ffn_out", il);
160        }
161
162        // dual ffn normalization (post)
163        cur = build_norm(cur,
164                model.layers[il].ffn_post_norm, NULL,
165                LLM_NORM_RMS, il);
166        cb(cur, "ffn_post_norm", il);
167
168        cur = ggml_add(ctx0, cur, ffn_inp);
169        cur = build_cvec(cur, il);
170        cb(cur, "l_out", il);
171
172        // input for next layer
173        inpL = cur;
174    }
175
176    cur = inpL;
177
178    cur = build_norm(cur,
179            model.output_norm, NULL,
180            LLM_NORM_RMS, -1);
181    cb(cur, "result_norm", -1);
182
183    res->t_embd = cur;
184
185    // lm_head
186    cur = build_lora_mm(model.output, cur);
187    cb(cur, "result_output", -1);
188    res->t_logits = cur;
189
190    ggml_build_forward_expand(gf, cur);
191}