1#include "models.h"
2
3llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
4 const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
5
6 const uint32_t n_embd_head_qk_rope = hparams.n_rot;
7 const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
8
9 const uint32_t kv_lora_rank = hparams.n_lora_kv;
10
11 ggml_tensor * cur;
12 ggml_tensor * inpL;
13
14 // {n_embd, n_tokens}
15 inpL = build_inp_embd(model.tok_embd);
16
17 // inp_pos - contains the positions
18 ggml_tensor * inp_pos = build_inp_pos();
19
20 auto * inp_attn = build_attn_inp_kv();
21
22 ggml_tensor * inp_out_ids = build_inp_out_ids();
23
24 for (int il = 0; il < n_layer; ++il) {
25 ggml_tensor * inpSA = inpL;
26
27 // norm
28 cur = build_norm(inpL,
29 model.layers[il].attn_norm, NULL,
30 LLM_NORM_RMS, il);
31 cb(cur, "attn_norm", il);
32
33 // self_attention
34 {
35 ggml_tensor * q = NULL;
36 q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
37 cb(q, "q", il);
38
39 // split into {n_head * n_embd_head_qk_nope, n_tokens}
40 ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
41 ggml_row_size(q->type, hparams.n_embd_head_k),
42 ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
43 0);
44 cb(q_nope, "q_nope", il);
45
46 // and {n_head * n_embd_head_qk_rope, n_tokens}
47 ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
48 ggml_row_size(q->type, hparams.n_embd_head_k),
49 ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
50 ggml_row_size(q->type, n_embd_head_qk_nope));
51 cb(q_pe, "q_pe", il);
52
53 // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
54 ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
55 cb(kv_pe_compresseed, "kv_pe_compresseed", il);
56
57 // split into {kv_lora_rank, n_tokens}
58 ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
59 kv_pe_compresseed->nb[1],
60 0);
61 cb(kv_compressed, "kv_compressed", il);
62
63 // and {n_embd_head_qk_rope, n_tokens}
64 ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
65 kv_pe_compresseed->nb[1],
66 kv_pe_compresseed->nb[1],
67 ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
68 cb(k_pe, "k_pe", il);
69
70 kv_compressed = build_norm(kv_compressed,
71 model.layers[il].attn_kv_a_norm, NULL,
72 LLM_NORM_RMS, il);
73 cb(kv_compressed, "kv_compressed", il);
74
75 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
76 ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
77 cb(kv, "kv", il);
78
79 // split into {n_head * n_embd_head_qk_nope, n_tokens}
80 ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
81 ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
82 ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
83 0);
84 cb(k_nope, "k_nope", il);
85
86 // and {n_head * n_embd_head_v, n_tokens}
87 ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
88 ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
89 ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
90 ggml_row_size(kv->type, (n_embd_head_qk_nope)));
91 cb(v_states, "v_states", il);
92
93 v_states = ggml_cont(ctx0, v_states);
94 cb(v_states, "v_states", il);
95
96 v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
97 ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
98 0);
99 cb(v_states, "v_states", il);
100
101 q_pe = ggml_rope_ext(
102 ctx0, q_pe, inp_pos, nullptr,
103 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
104 ext_factor, attn_factor, beta_fast, beta_slow
105 );
106 cb(q_pe, "q_pe", il);
107
108 // shared RoPE key
109 k_pe = ggml_rope_ext(
110 ctx0, k_pe, inp_pos, nullptr,
111 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
112 ext_factor, attn_factor, beta_fast, beta_slow
113 );
114 cb(k_pe, "k_pe", il);
115
116 ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
117 cb(q_states, "q_states", il);
118
119 ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
120 cb(k_states, "k_states", il);
121
122 cur = build_attn(inp_attn,
123 model.layers[il].wo, NULL,
124 q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il);
125 }
126 if (il == n_layer - 1 && inp_out_ids) {
127 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
128 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
129 }
130 ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
131 cb(ffn_inp, "ffn_inp", il);
132
133 cur = build_norm(ffn_inp,
134 model.layers[il].ffn_norm, NULL,
135 LLM_NORM_RMS, il);
136 cb(cur, "ffn_norm", il);
137
138 cur = build_ffn(cur,
139 model.layers[il].ffn_up, NULL, NULL,
140 NULL, NULL, NULL,
141 model.layers[il].ffn_down, NULL, NULL,
142 NULL,
143 LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
144 cb(cur, "ffn_out", il);
145
146 cur = ggml_add(ctx0, cur, ffn_inp);
147
148 cur = build_cvec(cur, il);
149 cb(cur, "l_out", il);
150
151 // input for next layer
152 inpL = cur;
153 }
154 cur = inpL;
155
156 cur = build_norm(cur,
157 model.output_norm, NULL,
158 LLM_NORM_RMS, -1);
159
160 cb(cur, "result_norm", -1);
161 res->t_embd = cur;
162
163 cur = build_lora_mm(model.output, cur);
164
165 cb(cur, "result_output", -1);
166 res->t_logits = cur;
167
168 ggml_build_forward_expand(gf, cur);
169}