path: root/llama.cpp/src/llama-graph.cpp
Diffstat (limited to 'llama.cpp/src/llama-graph.cpp')
-rw-r--r--  llama.cpp/src/llama-graph.cpp  2626
1 file changed, 2626 insertions, 0 deletions
diff --git a/llama.cpp/src/llama-graph.cpp b/llama.cpp/src/llama-graph.cpp
new file mode 100644
index 0000000..bba747d
--- /dev/null
+++ b/llama.cpp/src/llama-graph.cpp
@@ -0,0 +1,2626 @@
1#include "llama-graph.h"
2
3#include "llama-impl.h"
4#include "llama-batch.h"
5#include "llama-cparams.h"
6
7#include "llama-kv-cache.h"
8#include "llama-kv-cache-iswa.h"
9#include "llama-memory-hybrid.h"
10#include "llama-memory-hybrid-iswa.h"
11#include "llama-memory-recurrent.h"
12
13#include <cassert>
14#include <cmath>
15#include <cstring>
16#include <numeric>
17#include <sstream>
18#include <unordered_set>
19
20void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
21 if (ubatch->token) {
22 const int64_t n_tokens = ubatch->n_tokens;
23
24 ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
25 }
26
27 if (ubatch->embd) {
28 GGML_ASSERT(n_embd == embd->ne[0]);
29
30 const int64_t n_tokens = ubatch->n_tokens;
31
32 ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
33 }
34}
35
36bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
37 bool res = true;
38
39 res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
40 res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
41
42 return res;
43}
44
45void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
46 if (ubatch->pos && pos) {
47 const int64_t n_tokens = ubatch->n_tokens;
48
49 if (ubatch->token && n_pos_per_embd == 4) {
50 // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
 51 // the first 3 dims are the same, and the 4th dim is all 0
52 std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
53 // copy the first dimension
54 for (int i = 0; i < n_tokens; ++i) {
55 pos_data[ i] = ubatch->pos[i];
56 pos_data[ n_tokens + i] = ubatch->pos[i];
57 pos_data[2 * n_tokens + i] = ubatch->pos[i];
58 pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
59 }
60 ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
61 } else {
62 ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
63 }
64 }
65}
66
67bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
68 bool res = true;
69
70 res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
71
72 return res;
73}
74
75void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
76 if (ubatch->pos && attn_scale) {
77 const int64_t n_tokens = ubatch->n_tokens;
78
79 GGML_ASSERT(f_attn_temp_scale != 0.0f);
80 GGML_ASSERT(n_attn_temp_floor_scale != 0);
81
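// a sketch of the per-token attention-temperature formula implemented by the loop below:
//   scale(pos) = log(floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1) * f_attn_temp_scale + 1
// i.e. the scale grows logarithmically with position, in steps of n_attn_temp_floor_scale tokens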
82 std::vector<float> attn_scale_data(n_tokens, 0.0f);
83 for (int i = 0; i < n_tokens; ++i) {
84 const float pos = ubatch->pos[i];
85 attn_scale_data[i] = std::log(
86 std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
87 ) * f_attn_temp_scale + 1.0;
88 }
89
90 ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
91 }
92}
93
94void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
95 if (pos_bucket) {
96 const int64_t n_tokens = ubatch->n_tokens;
97
98 GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
99 GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
100
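// for each (key i, query j) pair the loops below store the relative-position bucket of (pos[i], pos[j]);
// llama_relative_position_bucket is assumed to implement T5-style bucketing over hparams.n_rel_attn_bkts buckets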
101 int32_t * data = (int32_t *) pos_bucket->data;
102
103 for (int j = 0; j < n_tokens; ++j) {
104 for (int i = 0; i < n_tokens; ++i) {
105 data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
106 }
107 }
108 }
109}
110
111void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
112 if (pos_bucket) {
113 mctx->set_input_pos_bucket(pos_bucket, ubatch);
114 }
115}
116
117void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
118 GGML_ASSERT(out_ids);
119
120 const int64_t n_tokens = ubatch->n_tokens;
121
122 GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
123 int32_t * data = (int32_t *) out_ids->data;
124
125 if (n_outputs == n_tokens) {
126 for (int i = 0; i < n_tokens; ++i) {
127 data[i] = i;
128 }
129
130 return;
131 }
132
133 GGML_ASSERT(ubatch->output);
134
135 int n_outputs = 0;
136
137 for (int i = 0; i < n_tokens; ++i) {
138 if (ubatch->output[i]) {
139 data[n_outputs++] = i;
140 }
141 }
142}
143
144bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
145 bool res = true;
146
147 res &= n_outputs == params.n_outputs;
148
149 return res;
150}
151
152void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
153 if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
154 const int64_t n_tokens = ubatch->n_tokens;
155 const int64_t n_seq_tokens = ubatch->n_seq_tokens;
156 const int64_t n_seqs_unq = ubatch->n_seqs_unq;
157
158 GGML_ASSERT(mean);
159 GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
160
161 float * data = (float *) mean->data;
162 memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
163
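// a sketch of what the loops below build: a [n_tokens, n_seqs_unq] pooling matrix whose entry (i, s) is
// 1/len(s) for tokens i belonging to sequence s and 0 otherwise, so multiplying it with the token
// embeddings yields the per-sequence mean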
164 std::vector<uint64_t> sums(n_seqs_unq, 0);
165 for (int i = 0; i < n_tokens; i += n_seq_tokens) {
166 for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
167 const llama_seq_id seq_id = ubatch->seq_id[i][s];
168 const int32_t seq_idx = ubatch->seq_idx[seq_id];
169
170 sums[seq_idx] += ubatch->n_seq_tokens;
171 }
172 }
173
174 std::vector<float> div(n_seqs_unq, 0.0f);
175 for (int s = 0; s < n_seqs_unq; ++s) {
176 const uint64_t sum = sums[s];
177 if (sum > 0) {
178 div[s] = 1.0f/float(sum);
179 }
180 }
181
182 for (int i = 0; i < n_tokens; i += n_seq_tokens) {
183 for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
184 const llama_seq_id seq_id = ubatch->seq_id[i][s];
185 const int32_t seq_idx = ubatch->seq_idx[seq_id];
186
187 for (int j = 0; j < n_seq_tokens; ++j) {
188 data[seq_idx*n_tokens + i + j] = div[seq_idx];
189 }
190 }
191 }
192 }
193}
194
195void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
196 const int64_t n_tokens = ubatch->n_tokens;
197 const int64_t n_seqs_unq = ubatch->n_seqs_unq;
198
199 if (cparams.embeddings && (
200 cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
201 cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
202 cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
203 )) {
204 GGML_ASSERT(cls);
205 GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
206
207 uint32_t * data = (uint32_t *) cls->data;
208 memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
209
210 std::vector<int> target_pos(n_seqs_unq, -1);
211 std::vector<int> target_row(n_seqs_unq, -1);
212
213 const bool last = (
214 cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
215 (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
216 );
217
218 for (int i = 0; i < n_tokens; ++i) {
219 const llama_pos pos = ubatch->pos[i];
220
221 for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
222 const llama_seq_id seq_id = ubatch->seq_id[i][s];
223 const int32_t seq_idx = ubatch->seq_idx[seq_id];
224
225 if (
226 (target_pos[seq_idx] == -1) ||
227 ( last && pos >= target_pos[seq_idx]) ||
228 (!last && pos < target_pos[seq_idx])
229 ) {
230 target_pos[seq_idx] = pos;
231 target_row[seq_idx] = i;
232 }
233 }
234 }
235
236 for (int s = 0; s < n_seqs_unq; ++s) {
237 if (target_row[s] >= 0) {
238 data[s] = target_row[s];
239 }
240 }
241 }
242}
243
244void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
245 GGML_UNUSED(ubatch);
246
247 const int64_t n_rs = mctx->get_n_rs();
248
249 if (s_copy) {
250 GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
251 int32_t * data = (int32_t *) s_copy->data;
252
253 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
254 for (uint32_t i = 0; i < n_rs; ++i) {
255 data[i] = mctx->s_copy(i);
256 }
257 }
258}
259
260bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
261 const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
262
263 this->mctx = mctx;
264
265 bool res = true;
266
267 res &= s_copy->ne[0] == mctx->get_n_rs();
268
269 res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
270 res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
271
272 res &= head == mctx->get_head();
273 res &= rs_z == mctx->get_rs_z();
274
275 return res;
276}
277
278void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
279 GGML_UNUSED(ubatch);
280
281 if (cross_embd && !cross->v_embd.empty()) {
282 assert(cross_embd->type == GGML_TYPE_F32);
283
284 ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
285 }
286}
287
288static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
289 LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
290 const char * swa_type_str = "unknown";
291
292 switch (swa_type) {
293 case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
294 case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
295 case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
296 case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
297 };
298
299 LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
300 LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
301 LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
302
303 LLAMA_LOG_DEBUG(" ");
304 for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
305 LLAMA_LOG_DEBUG("%2d", j);
306 }
307 LLAMA_LOG_DEBUG("\n");
308
309 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
310 LLAMA_LOG_DEBUG(" %2d ", i);
311 for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
312 float val = data[i * n_kv + j];
313 if (val == -INFINITY) {
314 LLAMA_LOG_DEBUG(" ∞");
315 } else {
316 LLAMA_LOG_DEBUG(" 0");
317 }
318 }
319 LLAMA_LOG_DEBUG("\n");
320 }
321}
322
323void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
324 const int64_t n_kv = ubatch->n_tokens;
325 const int64_t n_tokens = ubatch->n_tokens;
326
327 const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
328 for (int i1 = 0; i1 < n_tokens; ++i1) {
329 const llama_seq_id s1 = ubatch->seq_id[i1][0];
330 const llama_pos p1 = ubatch->pos[i1];
331
332 const uint64_t idst = i1*n_kv;
333
334 for (int i0 = 0; i0 < n_tokens; ++i0) {
335 const llama_seq_id s0 = ubatch->seq_id[i0][0];
336 const llama_pos p0 = ubatch->pos[i0];
337
338 // mask different sequences
339 if (s0 != s1) {
340 continue;
341 }
342
343 // mask future tokens
344 if (cparams.causal_attn && p0 > p1) {
345 continue;
346 }
347
348 // apply SWA if any
349 if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
350 continue;
351 }
352
353 data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
354 }
355 }
356 };
357
358 {
359 GGML_ASSERT(self_kq_mask);
360 GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
361
362 float * data = (float *) self_kq_mask->data;
363
364 std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
365
366 fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
367
368 if (debug) {
369 print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
370 }
371 }
372
373 if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
374 GGML_ASSERT(self_kq_mask_swa);
375 GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
376
377 float * data = (float *) self_kq_mask_swa->data;
378
379 std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
380
381 fill_mask(data, hparams.n_swa, hparams.swa_type);
382
383 if (debug) {
384 print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
385 }
386 }
387}
388
389void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
390 mctx->set_input_k_idxs(self_k_idxs, ubatch);
391 mctx->set_input_v_idxs(self_v_idxs, ubatch);
392
393 mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
394}
395
396bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
397 const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
398
399 this->mctx = mctx;
400
401 bool res = true;
402
403 res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
404 //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
405
406 res &= self_kq_mask->ne[0] == mctx->get_n_kv();
407 res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
408
409 return res;
410}
411
412void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
413 mctx->set_input_k_idxs(self_k_idxs, ubatch);
414
415 mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
416}
417
418bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
419 const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
420
421 this->mctx = mctx;
422
423 bool res = true;
424
425 res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
426
427 res &= self_kq_mask->ne[0] == mctx->get_n_kv();
428 res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
429
430 return res;
431}
432
433void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
434 mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
435 mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
436
437 mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
438
439 mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
440 mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
441
442 mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
443}
444
445bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
446 const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
447
448 this->mctx = mctx;
449
450 bool res = true;
451
452 res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
453 //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
454
455 res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
456 //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
457
458 res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
459 res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
460
461 res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
462 res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
463
464 return res;
465}
466
467void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
468 GGML_ASSERT(cross_kq_mask);
469
470 const int64_t n_enc = cross_kq_mask->ne[0];
471 const int64_t n_tokens = ubatch->n_tokens;
472
473 GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
474 GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
475
476 float * data = (float *) cross_kq_mask->data;
477
478 for (int i = 0; i < n_tokens; ++i) {
479 for (int j = 0; j < n_enc; ++j) {
480 float f = -INFINITY;
481
482 for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
483 const llama_seq_id seq_id = ubatch->seq_id[i][s];
484
485 if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
486 f = 0.0f;
487 }
488 }
489
490 data[i*n_enc + j] = f;
491 }
492 }
493}
494
495void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
496 mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
497 mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
498
499 mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
500
501 const int64_t n_rs = mctx->get_recr()->get_n_rs();
502
503 if (inp_rs->s_copy) {
504 GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
505 int32_t * data = (int32_t *) inp_rs->s_copy->data;
506
507 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
508 for (uint32_t i = 0; i < n_rs; ++i) {
509 data[i] = mctx->get_recr()->s_copy(i);
510 }
511 }
512}
513
514bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
515 const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
516
517 this->mctx = mctx;
518
519 bool res = true;
520
521 res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
522 //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
523
524 res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
525 res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
526
527 res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
528
529 res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
530 res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
531
532 res &= inp_rs->head == mctx->get_recr()->get_head();
533 res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
534
535 return res;
536}
537
538// TODO: Hybrid input classes are a bit redundant.
539// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
540// Refactoring is required in the future.
541void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
542 mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
543
544 mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
545
546 const int64_t n_rs = mctx->get_recr()->get_n_rs();
547
548 if (inp_rs->s_copy) {
549 GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
550 int32_t * data = (int32_t *) inp_rs->s_copy->data;
551
552 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
553 for (uint32_t i = 0; i < n_rs; ++i) {
554 data[i] = mctx->get_recr()->s_copy(i);
555 }
556 }
557}
558
559bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
560 const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
561
562 this->mctx = mctx;
563
564 bool res = true;
565
566 res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
567
568 res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
569 res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
570
571 res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
572
573 res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
574 res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
575
576 res &= inp_rs->head == mctx->get_recr()->get_head();
577 res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
578
579 return res;
580}
581
582void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
583 const auto * attn_ctx = mctx->get_attn();
584
585 // base tensors may not be allocated if there are no non-SWA attention layers
586 if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
587 attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
588 attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
589
590 attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
591 }
592
593 // swa tensors may not be allocated if there are no SWA attention layers
594 if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
595 attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
596 attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
597
598 attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
599 }
600
601 const int64_t n_rs = mctx->get_recr()->get_n_rs();
602
603 if (inp_rs->s_copy) {
604 GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
605 int32_t * data = (int32_t *) inp_rs->s_copy->data;
606
607 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
608 for (uint32_t i = 0; i < n_rs; ++i) {
609 data[i] = mctx->get_recr()->s_copy(i);
610 }
611 }
612}
613
614bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
615 const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
616
617 this->mctx = mctx;
618
619 bool res = true;
620
621 const auto * attn_ctx = mctx->get_attn();
622
623 // base tensors may not be allocated if there are no non-SWA attention layers
624 if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
625 res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
626 //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
627
628 res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
629 res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
630 }
631
632 // swa tensors may not be allocated if there are no SWA attention layers
633 if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
634 res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
635 //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
636
637 res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
638 res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
639 }
640
641 res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
642
643 res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
644 res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
645
646 res &= inp_rs->head == mctx->get_recr()->get_head();
647 res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
648
649 return res;
650}
651
652void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
653 // set the inputs only for the active samplers in the current ubatch
654 std::unordered_set<llama_seq_id> active_samplers;
655 for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
656 if (ubatch->output[i]) {
657 llama_seq_id seq_id = ubatch->seq_id[i][0];
658 active_samplers.insert(seq_id);
659 }
660 }
661
662 for (auto seq_id : active_samplers) {
663 if (samplers.find(seq_id) == samplers.end()) {
664 continue;
665 }
666
667 auto & sampler = samplers[seq_id];
668
669 if (sampler->iface->backend_set_input) {
670 sampler->iface->backend_set_input(sampler);
671 }
672 }
673}
674
675bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
676 if (samplers.size() != params.samplers.size()) {
677 return false;
678 }
679
680 for (const auto & [seq_id, sampler] : params.samplers) {
681 if (samplers[seq_id] != sampler) {
682 return false;
683 }
684 }
685
686 return true;
687}
688
689//
690// llm_graph_result
691//
692
693llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
694 reset();
695
696 const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
697 debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
698}
699
700int64_t llm_graph_result::get_max_nodes() const {
701 return max_nodes;
702}
703
704void llm_graph_result::reset() {
705 t_inp_tokens = nullptr;
706 t_inp_embd = nullptr;
707 t_logits = nullptr;
708 t_embd = nullptr;
709 t_embd_pooled = nullptr;
710 t_sampled.clear();
711 t_sampled_probs.clear();
712 t_sampled_logits.clear();
713 t_candidates.clear();
714
715 params = {};
716
717 inputs.clear();
718
719 buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
720
721 ggml_init_params params = {
722 /*.mem_size =*/ buf_compute_meta.size(),
723 /*.mem_buffer =*/ buf_compute_meta.data(),
724 /*.no_alloc =*/ true,
725 };
726
727 ctx_compute.reset(ggml_init(params));
728
729 gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
730}
731
732void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
733 for (auto & input : inputs) {
734 input->set_input(ubatch);
735 }
736}
737
738void llm_graph_result::set_outputs() {
739 if (t_logits != nullptr) {
740 ggml_set_output(t_logits);
741 }
742 if (t_embd != nullptr) {
743 ggml_set_output(t_embd);
744 }
745 if (t_embd_pooled != nullptr) {
746 ggml_set_output(t_embd_pooled);
747 }
748 for (auto & [seq_id, t] : t_sampled) {
749 if (t != nullptr) {
750 ggml_set_output(t);
751 }
752 }
753 for (auto & [seq_id, t] : t_sampled_probs) {
754 if (t != nullptr) {
755 ggml_set_output(t);
756 }
757 }
758 for (auto & [seq_id, t] : t_sampled_logits) {
759 if (t != nullptr) {
760 ggml_set_output(t);
761 }
762 }
763 for (auto & [seq_id, t] : t_candidates) {
764 if (t != nullptr) {
765 ggml_set_output(t);
766 }
767 }
768}
769
770bool llm_graph_result::can_reuse(const llm_graph_params & params) {
771 if (!this->params.allow_reuse(params)) {
772 if (debug > 1) {
773 LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
774 }
775
776 return false;
777 }
778
779 if (debug > 1) {
780 LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
781 }
782
783 bool res = true;
784
785 for (auto & input : inputs) {
786 const bool cur = input->can_reuse(params);
787
788 if (debug > 1) {
789 LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
790 }
791
792 res = res && cur;
793 }
794
795 if (debug > 0) {
796 LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
797 }
798
799 return res;
800}
801
802llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
803 inputs.emplace_back(std::move(input));
804 return inputs.back().get();
805}
806
807void llm_graph_result::set_params(const llm_graph_params & params) {
808 this->params = params;
809}
810
811//
812// llm_graph_context
813//
814
815llm_graph_context::llm_graph_context(const llm_graph_params & params) :
816 arch (params.arch),
817 hparams (params.hparams),
818 cparams (params.cparams),
819 ubatch (params.ubatch),
820 n_embd (hparams.n_embd),
821 n_layer (hparams.n_layer),
822 n_rot (hparams.n_rot),
823 n_ctx (cparams.n_ctx),
824 n_head (hparams.n_head()),
825 n_head_kv (hparams.n_head_kv()),
826 n_embd_head_k (hparams.n_embd_head_k),
827 n_embd_k_gqa (hparams.n_embd_k_gqa()),
828 n_embd_head_v (hparams.n_embd_head_v),
829 n_embd_v_gqa (hparams.n_embd_v_gqa()),
830 n_expert (hparams.n_expert),
831 n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
832 freq_base (cparams.rope_freq_base),
833 freq_scale (cparams.rope_freq_scale),
834 ext_factor (cparams.yarn_ext_factor),
835 attn_factor (cparams.yarn_attn_factor),
836 beta_fast (cparams.yarn_beta_fast),
837 beta_slow (cparams.yarn_beta_slow),
838 norm_eps (hparams.f_norm_eps),
839 norm_rms_eps (hparams.f_norm_rms_eps),
840 n_tokens (ubatch.n_tokens),
841 n_outputs (params.n_outputs),
842 n_ctx_orig (cparams.n_ctx_orig_yarn),
843 pooling_type (cparams.pooling_type),
844 rope_type (hparams.rope_type),
845 sched (params.sched),
846 backend_cpu (params.backend_cpu),
847 cvec (params.cvec),
848 loras (params.loras),
849 mctx (params.mctx),
850 cross (params.cross),
851 samplers (params.samplers),
852 cb_func (params.cb),
853 res (params.res),
854 ctx0 (res->get_ctx()),
855 gf (res->get_gf()) {
856 res->set_params(params);
857 }
858
859void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
860 if (cb_func) {
861 cb_func(ubatch, cur, name, il);
862 }
863}
864
865ggml_tensor * llm_graph_context::build_cvec(
866 ggml_tensor * cur,
867 int il) const {
868 return cvec->apply_to(ctx0, cur, il);
869}
870
871ggml_tensor * llm_graph_context::build_lora_mm(
872 ggml_tensor * w,
873 ggml_tensor * cur) const {
874 ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
875
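// the loop below adds the low-rank LoRA delta for every active adapter:
//   res += scale * B(A(cur)), where scale is expected to be adapter_scale * alpha / rank
// (build_lora_mm_id below computes the same scale explicitly)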
876 for (const auto & lora : *loras) {
877 llama_adapter_lora_weight * lw = lora.first->get_weight(w);
878 if (lw == nullptr) {
879 continue;
880 }
881
882 const float adapter_scale = lora.second;
883 const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
884
885 ggml_tensor * ab_cur = ggml_mul_mat(
886 ctx0, lw->b,
887 ggml_mul_mat(ctx0, lw->a, cur)
888 );
889
890 ab_cur = ggml_scale(ctx0, ab_cur, scale);
891 res = ggml_add(ctx0, res, ab_cur);
892 }
893
894 return res;
895}
896
897ggml_tensor * llm_graph_context::build_lora_mm_id(
898 ggml_tensor * w, // ggml_tensor * as
899 ggml_tensor * cur, // ggml_tensor * b
900 ggml_tensor * ids) const {
901 ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
902 for (const auto & lora : *loras) {
903 llama_adapter_lora_weight * lw = lora.first->get_weight(w);
904 if (lw == nullptr) {
905 continue;
906 }
907
908 const float alpha = lora.first->alpha;
909 const float rank = (float) lw->b->ne[0];
910 const float scale = alpha ? lora.second * alpha / rank : lora.second;
911
912 ggml_tensor * ab_cur = ggml_mul_mat_id(
913 ctx0, lw->b,
914 ggml_mul_mat_id(ctx0, lw->a, cur, ids),
915 ids
916 );
917
918 ab_cur = ggml_scale(ctx0, ab_cur, scale);
919 res = ggml_add(ctx0, res, ab_cur);
920 }
921
922 return res;
923}
924
925ggml_tensor * llm_graph_context::build_norm(
926 ggml_tensor * cur,
927 ggml_tensor * mw,
928 ggml_tensor * mb,
929 llm_norm_type type,
930 int il) const {
931 switch (type) {
932 case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
933 case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
934 case LLM_NORM_GROUP:
935 {
936 cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
937 cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
938 cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
939 } break;
940 }
941
942 if (mw || mb) {
943 cb(cur, "norm", il);
944 }
945
946 if (mw) {
947 cur = ggml_mul(ctx0, cur, mw);
948 if (mb) {
949 cb(cur, "norm_w", il);
950 }
951 }
952
953 if (mb) {
954 cur = ggml_add(ctx0, cur, mb);
955 }
956
957 return cur;
958}
959
960ggml_tensor * llm_graph_context::build_ffn(
961 ggml_tensor * cur,
962 ggml_tensor * up,
963 ggml_tensor * up_b,
964 ggml_tensor * up_s,
965 ggml_tensor * gate,
966 ggml_tensor * gate_b,
967 ggml_tensor * gate_s,
968 ggml_tensor * down,
969 ggml_tensor * down_b,
970 ggml_tensor * down_s,
971 ggml_tensor * act_scales,
972 llm_ffn_op_type type_op,
973 llm_ffn_gate_type type_gate,
974 int il) const {
975 ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
976 cb(tmp, "ffn_up", il);
977
978 if (up_b) {
979 tmp = ggml_add(ctx0, tmp, up_b);
980 cb(tmp, "ffn_up_b", il);
981 }
982
983 if (up_s) {
984 tmp = ggml_mul(ctx0, tmp, up_s);
985 cb(tmp, "ffn_up_s", il);
986 }
987
988 if (gate) {
989 switch (type_gate) {
990 case LLM_FFN_SEQ:
991 {
992 cur = build_lora_mm(gate, tmp);
993 cb(cur, "ffn_gate", il);
994 } break;
995 case LLM_FFN_PAR:
996 {
997 cur = build_lora_mm(gate, cur);
998 cb(cur, "ffn_gate", il);
999 } break;
1000 }
1001
1002 if (gate_b) {
1003 cur = ggml_add(ctx0, cur, gate_b);
1004 cb(cur, "ffn_gate_b", il);
1005 }
1006
1007 if (gate_s) {
1008 cur = ggml_mul(ctx0, cur, gate_s);
1009 cb(cur, "ffn_gate_s", il);
1010 }
1011
1012 } else {
1013 cur = tmp;
1014 }
1015
1016 switch (type_op) {
1017 case LLM_FFN_SILU:
1018 if (gate && type_gate == LLM_FFN_PAR) {
1019 // Step35: HF clamps gate (after SiLU) and up before multiplication
1020 if (arch == LLM_ARCH_STEP35 && il >= 0) {
1021 const float limit = hparams.swiglu_clamp_shexp[il];
1022 constexpr float eps = 1e-6f;
1023 if (limit > eps) {
1024 ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1025 cb(gate_act, "ffn_silu", il);
1026 gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1027 cb(gate_act, "ffn_silu_clamped", il);
1028
1029 tmp = ggml_clamp(ctx0, tmp, -limit, limit);
1030 cb(tmp, "ffn_up_clamped", il);
1031
1032 cur = ggml_mul(ctx0, gate_act, tmp);
1033 cb(cur, "ffn_swiglu_limited", il);
1034 type_gate = LLM_FFN_SEQ;
1035 break;
1036 }
1037 }
1038
1039 cur = ggml_swiglu_split(ctx0, cur, tmp);
1040 cb(cur, "ffn_swiglu", il);
1041 type_gate = LLM_FFN_SEQ;
1042 } else {
1043 cur = ggml_silu(ctx0, cur);
1044 cb(cur, "ffn_silu", il);
1045 } break;
1046 case LLM_FFN_GELU:
1047 if (gate && type_gate == LLM_FFN_PAR) {
1048 cur = ggml_geglu_split(ctx0, cur, tmp);
1049 cb(cur, "ffn_geglu", il);
1050 type_gate = LLM_FFN_SEQ;
1051 } else {
1052 cur = ggml_gelu(ctx0, cur);
1053 cb(cur, "ffn_gelu", il);
1054 if (act_scales != NULL) {
1055 cur = ggml_div(ctx0, cur, act_scales);
1056 cb(cur, "ffn_act", il);
1057 }
1058 } break;
1059 case LLM_FFN_RELU:
1060 if (gate && type_gate == LLM_FFN_PAR) {
1061 cur = ggml_reglu_split(ctx0, cur, tmp);
1062 cb(cur, "ffn_reglu", il);
1063 type_gate = LLM_FFN_SEQ;
1064 } else {
1065 cur = ggml_relu(ctx0, cur);
1066 cb(cur, "ffn_relu", il);
1067 } break;
1068 case LLM_FFN_RELU_SQR:
1069 {
1070 cur = ggml_relu(ctx0, cur);
1071 cb(cur, "ffn_relu", il);
1072
1073 cur = ggml_sqr(ctx0, cur);
1074 cb(cur, "ffn_sqr(relu)", il);
1075 } break;
1076 case LLM_FFN_SWIGLU:
1077 {
1078 cur = ggml_swiglu(ctx0, cur);
1079 cb(cur, "ffn_swiglu", il);
1080 } break;
1081 case LLM_FFN_GEGLU:
1082 {
1083 cur = ggml_geglu(ctx0, cur);
1084 cb(cur, "ffn_geglu", il);
1085 } break;
1086 case LLM_FFN_REGLU:
1087 {
1088 cur = ggml_reglu(ctx0, cur);
1089 cb(cur, "ffn_reglu", il);
1090 } break;
1091 default:
1092 GGML_ABORT("fatal error");
1093 }
1094
1095 if (gate && type_gate == LLM_FFN_PAR) {
1096 cur = ggml_mul(ctx0, cur, tmp);
1097 cb(cur, "ffn_gate_par", il);
1098 }
1099
1100 if (down) {
1101 cur = build_lora_mm(down, cur);
1102 if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1103 // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1104 ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1105 }
1106 }
1107
1108 if (down_b) {
1109 cb(cur, "ffn_down", il);
1110 }
1111
1112 if (down_b) {
1113 cur = ggml_add(ctx0, cur, down_b);
1114 }
1115
1116 if (down_s) {
1117 cur = ggml_mul(ctx0, cur, down_s);
1118 cb(cur, "ffn_down_s", il);
1119 }
1120
1121 return cur;
1122}
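// a hypothetical usage sketch for a standard parallel SwiGLU FFN layer (tensor names are illustrative,
// not taken from this file):
//   cur = build_ffn(cur,
//           model.layers[il].ffn_up,   NULL, NULL,
//           model.layers[il].ffn_gate, NULL, NULL,
//           model.layers[il].ffn_down, NULL, NULL,
//           NULL,
//           LLM_FFN_SILU, LLM_FFN_PAR, il);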
1123
1124ggml_tensor * llm_graph_context::build_moe_ffn(
1125 ggml_tensor * cur,
1126 ggml_tensor * gate_inp,
1127 ggml_tensor * up_exps,
1128 ggml_tensor * gate_exps,
1129 ggml_tensor * down_exps,
1130 ggml_tensor * exp_probs_b,
1131 int64_t n_expert,
1132 int64_t n_expert_used,
1133 llm_ffn_op_type type_op,
1134 bool norm_w,
1135 bool scale_w,
1136 float w_scale,
1137 llama_expert_gating_func_type gating_op,
1138 int il,
1139 ggml_tensor * probs_in) const {
1140 return build_moe_ffn(
1141 cur,
1142 gate_inp, /* gate_inp_b */ nullptr,
1143 up_exps, /* up_exps_b */ nullptr,
1144 gate_exps, /* gate_exps_b */ nullptr,
1145 down_exps, /* down_exps_b */ nullptr,
1146 exp_probs_b,
1147 n_expert,
1148 n_expert_used,
1149 type_op,
1150 norm_w,
1151 scale_w,
1152 w_scale,
1153 gating_op,
1154 il,
1155 probs_in
1156 );
1157}
1158
1159ggml_tensor * llm_graph_context::build_moe_ffn(
1160 ggml_tensor * cur,
1161 ggml_tensor * gate_inp,
1162 ggml_tensor * gate_inp_b,
1163 ggml_tensor * up_exps,
1164 ggml_tensor * up_exps_b,
1165 ggml_tensor * gate_exps,
1166 ggml_tensor * gate_exps_b,
1167 ggml_tensor * down_exps,
1168 ggml_tensor * down_exps_b,
1169 ggml_tensor * exp_probs_b,
1170 int64_t n_expert,
1171 int64_t n_expert_used,
1172 llm_ffn_op_type type_op,
1173 bool norm_w,
1174 bool scale_w,
1175 float w_scale,
1176 llama_expert_gating_func_type gating_op,
1177 int il,
1178 ggml_tensor * probs_in) const {
1179 const int64_t n_embd = cur->ne[0];
1180 const int64_t n_tokens = cur->ne[1];
1181 const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
1182
1183 ggml_tensor * logits = nullptr;
1184
1185 if (probs_in == nullptr) {
1186 logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
1187 cb(logits, "ffn_moe_logits", il);
1188 } else {
1189 logits = probs_in;
1190 }
1191
1192 if (gate_inp_b) {
1193 logits = ggml_add(ctx0, logits, gate_inp_b);
1194 cb(logits, "ffn_moe_logits_biased", il);
1195 }
1196
1197 ggml_tensor * probs = nullptr;
1198 switch (gating_op) {
1199 case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
1200 {
1201 probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
1202 } break;
1203 case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
1204 {
1205 probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1206 } break;
1207 case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
1208 {
1209 probs = logits; // [n_expert, n_tokens]
1210 } break;
1211 default:
1212 GGML_ABORT("fatal error");
1213 }
1214 cb(probs, "ffn_moe_probs", il);
1215
1216 // add experts selection bias - introduced in DeepSeek V3
1217 // leave probs unbiased as it's later used to get expert weights
1218 ggml_tensor * selection_probs = probs;
1219 if (exp_probs_b != nullptr) {
1220 selection_probs = ggml_add(ctx0, probs, exp_probs_b);
1221 cb(selection_probs, "ffn_moe_probs_biased", il);
1222 }
1223
1224 // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
1225 // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
1226 if (arch == LLM_ARCH_LLAMA4) {
1227 selection_probs = logits;
1228 }
1229
1230 if (arch == LLM_ARCH_GROVEMOE) {
1231 selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
1232 cb(selection_probs, "ffn_moe_probs_biased", il);
1233 }
1234
1235 // select top n_group_used expert groups
1236 // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1237 if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1238 const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1239
1240 // organize experts into n_expert_groups
1241 ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1242
1243 ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1244 group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1245
1246 // get top n_group_used expert groups
1247 group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1248 group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1249
1250 ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1251 cb(expert_groups, "ffn_moe_group_topk", il);
1252
1253 // mask out the other groups
1254 selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1255 selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1256 selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1257 cb(selection_probs, "ffn_moe_probs_masked", il);
1258 }
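// a small worked example of the group selection above, assuming n_expert = 256 and hparams.n_expert_groups = 8:
// each group holds 32 experts, a group is scored by the sum of its top-2 selection probabilities, only the top
// hparams.n_group_used groups are kept, and every expert in the remaining groups is masked to -INF so it can
// never be selected by the top-k below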
1259
1260 // select experts
1261 ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1262 cb(selected_experts->src[0], "ffn_moe_argsort", il);
1263 cb(selected_experts, "ffn_moe_topk", il);
1264
1265 if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
1266 // TODO: Use scalar div instead when/if implemented
1267 ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
1268 selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
1269 probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
1270 } else {
1271 probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
1272 }
1273
1274 ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
1275 cb(weights, "ffn_moe_weights", il);
1276
1277
1278 if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
1279 weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1280 weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
1281 weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1282 cb(weights, "ffn_moe_weights_softmax", il);
1283 }
1284
1285 if (norm_w) {
1286 weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
1287
1288 ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
1289 cb(weights_sum, "ffn_moe_weights_sum", il);
1290
 1291 // Avoid division by zero, clamp to the smallest normal number representable by F16
1292 weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1293 cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1294
1295 weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
1296 cb(weights, "ffn_moe_weights_norm", il);
1297
1298 weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1299 }
1300 if (scale_w) {
1301 weights = ggml_scale(ctx0, weights, w_scale);
1302 cb(weights, "ffn_moe_weights_scaled", il);
1303 }
1304
 1305 // call early so that topk-moe can be used
1306 ggml_build_forward_expand(gf, weights);
1307
1308 cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
1309
1310 if (weight_before_ffn) {
1311 // repeat cur to [n_embd, n_expert_used, n_tokens]
1312 ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
1313 cur = ggml_mul(ctx0, repeated, weights);
1314 cb(cur, "ffn_moe_weighted", il);
1315 }
1316
1317 ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1318 cb(up, "ffn_moe_up", il);
1319
1320 if (up_exps_b) {
1321 up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1322 cb(up, "ffn_moe_up_biased", il);
1323 }
1324
1325 ggml_tensor * experts = nullptr;
1326 if (gate_exps) {
1327 cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1328 cb(cur, "ffn_moe_gate", il);
1329 } else {
1330 cur = up;
1331 }
1332
1333 if (gate_exps_b) {
1334 cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1335 cb(cur, "ffn_moe_gate_biased", il);
1336 }
1337
1338 switch (type_op) {
1339 case LLM_FFN_SILU:
1340 if (gate_exps) {
1341 // Step35: per-layer clamp for routed experts
1342 if (arch == LLM_ARCH_STEP35 && il >= 0) {
1343 const float limit = hparams.swiglu_clamp_exp[il];
1344 constexpr float eps = 1e-6f;
1345 if (limit > eps) {
1346 ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1347 cb(gate_act, "ffn_moe_silu", il);
1348 gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1349 cb(gate_act, "ffn_moe_silu_clamped", il);
1350
1351 up = ggml_clamp(ctx0, up, -limit, limit);
1352 cb(up, "ffn_moe_up_clamped", il);
1353
1354 cur = ggml_mul(ctx0, gate_act, up);
1355 cb(cur, "ffn_moe_swiglu_limited", il);
1356 break;
1357 }
1358 }
1359
1360 cur = ggml_swiglu_split(ctx0, cur, up);
1361 cb(cur, "ffn_moe_swiglu", il);
1362 } else {
1363 cur = ggml_silu(ctx0, cur);
1364 cb(cur, "ffn_moe_silu", il);
1365 } break;
1366 case LLM_FFN_GELU:
1367 if (gate_exps) {
1368 cur = ggml_geglu_split(ctx0, cur, up);
1369 cb(cur, "ffn_moe_geglu", il);
1370 } else {
1371 cur = ggml_gelu(ctx0, cur);
1372 cb(cur, "ffn_moe_gelu", il);
1373 } break;
1374 case LLM_FFN_SWIGLU_OAI_MOE:
1375 {
1376 // TODO: move to hparams?
1377 constexpr float alpha = 1.702f;
1378 constexpr float limit = 7.0f;
1379 cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1380 cb(cur, "ffn_moe_swiglu_oai", il);
1381 } break;
1382 case LLM_FFN_RELU:
1383 if (gate_exps) {
1384 cur = ggml_reglu_split(ctx0, cur, up);
1385 cb(cur, "ffn_moe_reglu", il);
1386 } else {
1387 cur = ggml_relu(ctx0, cur);
1388 cb(cur, "ffn_moe_relu", il);
1389 } break;
1390 case LLM_FFN_RELU_SQR:
1391 if (gate_exps) {
1392 // TODO: add support for gated squared relu
1393 GGML_ABORT("fatal error: gated squared relu not implemented");
1394 } else {
1395 cur = ggml_relu(ctx0, cur);
1396 cur = ggml_sqr(ctx0, cur);
1397 cb(cur, "ffn_moe_relu_sqr", il);
1398 } break;
1399 default:
1400 GGML_ABORT("fatal error");
1401 }
1402
1403 experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
1404 cb(experts, "ffn_moe_down", il);
1405
1406 if (down_exps_b) {
1407 experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1408 cb(experts, "ffn_moe_down_biased", il);
1409 }
1410
1411 if (!weight_before_ffn) {
1412 experts = ggml_mul(ctx0, experts, weights);
 1413 cb(experts, "ffn_moe_weighted", il);
1414 }
1415
1416 ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1417
1418 assert(n_expert_used > 0);
1419
1420 // order the views before the adds
1421 for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1422 cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1423
1424 ggml_build_forward_expand(gf, cur_experts[i]);
1425 }
1426
1427 // aggregate experts
1428 // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
 1429 // to avoid a potentially large number of add nodes during warmup
1430 // ref: https://github.com/ggml-org/llama.cpp/pull/14753
1431 ggml_tensor * moe_out = cur_experts[0];
1432
1433 for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1434 moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
1435 }
1436
1437 if (hparams.n_expert_used == 1) {
1438 // avoid returning a non-contiguous tensor
1439 moe_out = ggml_cont(ctx0, moe_out);
1440 }
1441
1442 cb(moe_out, "ffn_moe_out", il);
1443
1444 return moe_out;
1445}
1446
1447// input embeddings with optional lora
1448ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1449 const int64_t n_embd_inp = hparams.n_embd_inp();
1450 const int64_t n_embd = hparams.n_embd;
1451
1452 assert(n_embd_inp >= n_embd);
1453
1454 auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1455
1456 inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1457 cb(inp->tokens, "inp_tokens", -1);
1458 ggml_set_input(inp->tokens);
1459 res->t_inp_tokens = inp->tokens;
1460
1461 inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
1462 cb(inp->embd, "inp_embd", -1);
1463 ggml_set_input(inp->embd);
1464
1465 // select one of the 2 inputs, based on the batch contents
1466 // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1467 std::array<ggml_tensor *, 2> inps;
1468
1469 // token embeddings path (ubatch.token != nullptr)
1470 {
1471 auto & cur = inps[0];
1472
1473 cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1474
1475 // apply lora for embedding tokens if needed
1476 for (const auto & lora : *loras) {
1477 llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
1478 if (lw == nullptr) {
1479 continue;
1480 }
1481
1482 const float adapter_scale = lora.second;
1483 const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
1484
1485 ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
1486 ctx0, lw->b, // non-transposed lora_b
1487 ggml_get_rows(ctx0, lw->a, inp->tokens)
1488 ), scale);
1489
1490 cur = ggml_add(ctx0, cur, inpL_delta);
1491 }
1492
1493 if (n_embd_inp != n_embd) {
1494 cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
1495 }
1496 }
1497
1498 // vector embeddings path (ubatch.embd != nullptr)
1499 {
1500 auto & cur = inps[1];
1501
1502 cur = inp->embd;
1503 }
1504
1505 assert(ggml_are_same_shape (inps[0], inps[1]));
1506 assert(ggml_are_same_stride(inps[0], inps[1]));
1507
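// note: both candidate branches are kept in the graph; the index below selects branch 0 (token ids) or
// branch 1 (raw embeddings) based on the ubatch contents, presumably so the same graph layout can serve
// both kinds of batches (see the PR referenced above)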
1508 ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
1509
1510 if (n_embd_inp != n_embd) {
1511 cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
1512 }
1513
1514 res->t_inp_embd = cur;
1515
1516 // For Granite architecture
1517 if (hparams.f_embedding_scale != 0.0f) {
1518 cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1519 }
1520
1521 cb(cur, "embd", -1);
1522
1523 res->add_input(std::move(inp));
1524
1525 // make sure the produced embeddings are immediately materialized in the ggml graph
1526 // ref: https://github.com/ggml-org/llama.cpp/pull/18599
1527 ggml_build_forward_expand(gf, cur);
1528
1529 return cur;
1530}
1531
1532ggml_tensor * llm_graph_context::build_inp_pos() const {
1533 auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
1534
1535 auto & cur = inp->pos;
1536
1537 cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
1538 ggml_set_input(cur);
1539
1540 res->add_input(std::move(inp));
1541
1542 return cur;
1543}
1544
1545ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1546 auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
1547
1548 auto & cur = inp->attn_scale;
1549
 1550 // this needs to be 1x1xN for broadcasting
1551 cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1552 ggml_set_input(cur);
1553
1554 res->add_input(std::move(inp));
1555
1556 return cur;
1557}
1558
1559ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1560 // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
 1561 // but this would make the graph topology depend on the number of output tokens, which can interfere with
 1562 // features that require constant topology such as pipeline parallelism
1563 // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1564 //if (n_outputs < n_tokens) {
1565 // return nullptr;
1566 //}
1567
1568 auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
1569
1570 auto & cur = inp->out_ids;
1571
1572 cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
1573 ggml_set_input(cur);
1574
1575 res->add_input(std::move(inp));
1576
1577 return cur;
1578}
1579
1580ggml_tensor * llm_graph_context::build_inp_mean() const {
1581 auto inp = std::make_unique<llm_graph_input_mean>(cparams);
1582
1583 auto & cur = inp->mean;
1584
1585 cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
1586 ggml_set_input(cur);
1587
1588 res->add_input(std::move(inp));
1589
1590 return cur;
1591}
1592
1593ggml_tensor * llm_graph_context::build_inp_cls() const {
1594 auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
1595
1596 auto & cur = inp->cls;
1597
1598 cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
1599 ggml_set_input(cur);
1600
1601 res->add_input(std::move(inp));
1602
1603 return cur;
1604}
1605
1606ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1607 auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1608
1609 auto & cur = inp->cross_embd;
1610
1611 // if we have the output embeddings from the encoder, use them directly
1612 // TODO: needs more work to be correct, for now just use the tensor shape
1613 //if (cross->t_embd) {
1614 // cur = ggml_view_tensor(ctx0, cross->t_embd);
1615
1616 // return cur;
1617 //}
1618
1619 const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1620 const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1621
1622 cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1623 ggml_set_input(cur);
1624
1625 res->add_input(std::move(inp));
1626
1627 return cur;
1628}
1629
1630ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1631 auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
1632
1633 auto & cur = inp->pos_bucket;
1634
1635 cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
1636 ggml_set_input(cur);
1637
1638 res->add_input(std::move(inp));
1639
1640 return cur;
1641}
1642
1643ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1644 const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1645
1646 auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1647
1648 const auto n_kv = mctx_cur->get_n_kv();
1649
1650 auto & cur = inp->pos_bucket;
1651
1652 cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
1653 ggml_set_input(cur);
1654
1655 res->add_input(std::move(inp));
1656
1657 return cur;
1658}
1659
1660ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
1661 ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
1662 cb(pos_bucket_1d, "pos_bucket_1d", -1);
1663
1664 ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
1665
1666 pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
1667 pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
1668 pos_bias = ggml_cont (ctx0, pos_bias);
1669
1670 cb(pos_bias, "pos_bias", -1);
1671
1672 return pos_bias;
1673}
1674
1675ggml_tensor * llm_graph_context::build_attn_mha(
1676 ggml_tensor * q,
1677 ggml_tensor * k,
1678 ggml_tensor * v,
1679 ggml_tensor * kq_b,
1680 ggml_tensor * kq_mask,
1681 ggml_tensor * sinks,
1682 ggml_tensor * v_mla,
1683 float kq_scale,
1684 int il) const {
1685 const bool v_trans = v->nb[1] > v->nb[2];
1686
1687 // split the batch into streams if needed
1688 const auto n_stream = k->ne[3];
1689
1690 q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1691
1692 q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1693 k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1694 v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1695
1696 ggml_tensor * cur;
1697
1698 if (cparams.flash_attn && kq_b == nullptr) {
1699 GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1700
1701 if (v_trans) {
1702 v = ggml_transpose(ctx0, v);
1703 }
1704
1705 // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
1706 if (k->type == GGML_TYPE_F32) {
1707 k = ggml_cast(ctx0, k, GGML_TYPE_F16);
1708 }
1709
1710 if (v->type == GGML_TYPE_F32) {
1711 v = ggml_cast(ctx0, v, GGML_TYPE_F16);
1712 }
1713
1714 cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1715 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1716 cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1717
1718 ggml_flash_attn_ext_add_sinks(cur, sinks);
1719 ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1720
1721 if (v_mla) {
1722#if 0
1723 // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
 1724 // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1725 cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1726 cur = ggml_mul_mat(ctx0, v_mla, cur);
1727#else
1728 // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1729 // The permutations are noops and only change how the tensor data is interpreted.
1730 cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1731 cur = ggml_mul_mat(ctx0, v_mla, cur);
1732 cb(cur, "fattn_mla", il);
1733 cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1734 cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1735#endif
1736 }
1737
1738 cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1739 } else {
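// non-flash-attention path (a high-level sketch): materialize the full KQ matrix, apply optional
// soft-capping and/or the KQ bias, softmax together with the mask, then multiply by V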
1740 ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1741 cb(kq, "kq", il);
1742
1743 // note: this op tends to require high floating point range
1744 // while for some models F16 is enough, for others it is not, so we default to F32 here
1745 ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
1746
1747 if (arch == LLM_ARCH_GROK) {
 1748 // need to do the following:
 1749 // multiply by the attention output multiplier (f_attn_out_scale)
 1750 // and then apply soft-capping:
 1751 // kq = f_attn_logit_softcapping * tanh(kq / f_attn_logit_softcapping)
 1752 // before the softmax below
1753
1754 kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1755 cb(kq, "kq_tanh", il);
1756 kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1757 cb(kq, "kq_scaled", il);
1758 }
1759
1760 if (hparams.attn_soft_cap) {
1761 kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1762 cb(kq, "kq_scaled_1", il);
1763 kq = ggml_tanh (ctx0, kq);
1764 cb(kq, "kq_tanh", il);
1765 kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1766 cb(kq, "kq_scaled_2", il);
1767 }
1768
1769 if (kq_b) {
1770 kq = ggml_add(ctx0, kq, kq_b);
1771 cb(kq, "kq_plus_kq_b", il);
1772 }
1773
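        // ggml_soft_max_ext fuses the scaling, mask addition and softmax:
        //   kq = soft_max(kq*kq_scale + kq_mask), where the mask is scaled by per-head ALiBi slopes when f_max_alibi_bias > 0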
1774 kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1775 ggml_soft_max_add_sinks(kq, sinks);
1776 cb(kq, "kq_soft_max", il);
1777
1778 if (!v_trans) {
1779            // note: avoid this branch - the transpose + cont below is expensive and is only needed because V was not provided in transposed layout
1780 v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1781 cb(v, "v_cont", il);
1782 }
1783
1784 ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1785 cb(kqv, "kqv", il);
1786
1787 // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1788 if (v_mla) {
1789 kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1790 cb(kqv, "kqv_mla", il);
1791 }
1792
1793 cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1794
1795 // recombine streams
1796 cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1797
1798 if (!cparams.offload_kqv) {
1799 // all nodes between the KV store and the attention output are run on the CPU
1800 ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1801 }
1802 }
1803
1804 ggml_build_forward_expand(gf, cur);
1805
1806 return cur;
1807}
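
// For reference, the non-flash path above is equivalent to the textbook formulation
//   out = V * soft_max(scale * K^T Q + mask)
// evaluated independently per head and per stream. A minimal standalone sketch of the same
// computation with plain ggml ops (illustrative only, not part of this file; `naive_attn` is a
// hypothetical helper and assumes a ggml_context `ctx` with F32 tensors
// q [n_embd_head, n_tokens, n_head], k [n_embd_head, n_kv, n_head], v [n_embd_head_v, n_kv, n_head]):
#if 0
static ggml_tensor * naive_attn(ggml_context * ctx,
        ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * mask, float scale) {
    ggml_tensor * kq = ggml_mul_mat(ctx, k, q);                      // [n_kv, n_tokens, n_head]
    kq = ggml_soft_max_ext(ctx, kq, mask, scale, 0.0f);              // softmax over the n_kv dimension
    ggml_tensor * vt = ggml_cont(ctx, ggml_transpose(ctx, v));       // [n_kv, n_embd_head_v, n_head]
    return ggml_mul_mat(ctx, vt, kq);                                // [n_embd_head_v, n_tokens, n_head]
}
#endif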
1808
1809llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1810 auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1811
1812 // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1813 inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1814 ggml_set_input(inp->self_kq_mask);
1815
1816 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1817
1818 if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1819 inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1820 ggml_set_input(inp->self_kq_mask_swa);
1821
1822 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1823 } else {
1824 inp->self_kq_mask_swa = nullptr;
1825 inp->self_kq_mask_swa_cnv = nullptr;
1826 }
1827
1828 return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1829}
1830
1831ggml_tensor * llm_graph_context::build_attn(
1832 llm_graph_input_attn_no_cache * inp,
1833 ggml_tensor * wo,
1834 ggml_tensor * wo_b,
1835 ggml_tensor * q_cur,
1836 ggml_tensor * k_cur,
1837 ggml_tensor * v_cur,
1838 ggml_tensor * kq_b,
1839 ggml_tensor * sinks,
1840 ggml_tensor * v_mla,
1841 float kq_scale,
1842 int il) const {
1843 GGML_UNUSED(n_tokens);
1844
1845 // these nodes are added to the graph together so that they are not reordered
1846 // by doing so, the number of splits in the graph is reduced
1847 ggml_build_forward_expand(gf, q_cur);
1848 ggml_build_forward_expand(gf, k_cur);
1849 ggml_build_forward_expand(gf, v_cur);
1850
1851 const bool is_swa = hparams.is_swa(il);
1852
1853 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1854
1855 // [TAG_NO_CACHE_PAD]
1856 // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1857 // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1858 //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1859
1860 ggml_tensor * q = q_cur;
1861 ggml_tensor * k = k_cur;
1862 ggml_tensor * v = v_cur;
1863
1864 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1865 cb(cur, "kqv_out", il);
1866
1867 if (wo) {
1868 cur = build_lora_mm(wo, cur);
1869 }
1870
1871    if (wo_b) {
1872        //cb(cur, "kqv_wo", il);
1873        cur = ggml_add(ctx0, cur, wo_b);
1874    }
1878
1879 return cur;
1880}
1881
1882static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1883 ggml_context * ctx0,
1884 const llama_ubatch & ubatch,
1885 const llama_hparams & hparams,
1886 const llama_cparams & cparams,
1887 const llama_kv_cache_context * mctx_cur) {
1888
1889 auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1890
1891 {
1892 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1893
1894 const auto n_kv = mctx_cur->get_n_kv();
1895 const auto n_tokens = ubatch.n_tokens;
1896 const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1897
1898 inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1899 inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1900
1901 inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1902 ggml_set_input(inp->self_kq_mask);
1903
1904 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1905 }
1906
1907 return inp;
1908}
1909
1910llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1911 const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1912
1913 auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1914
1915 return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1916}
1917
1918ggml_tensor * llm_graph_context::build_attn(
1919 llm_graph_input_attn_kv * inp,
1920 ggml_tensor * wo,
1921 ggml_tensor * wo_b,
1922 ggml_tensor * q_cur,
1923 ggml_tensor * k_cur,
1924 ggml_tensor * v_cur,
1925 ggml_tensor * kq_b,
1926 ggml_tensor * sinks,
1927 ggml_tensor * v_mla, // TODO: remove
1928 float kq_scale,
1929 int il) const {
1930 GGML_ASSERT(v_mla == nullptr);
1931
1932 // these nodes are added to the graph together so that they are not reordered
1933 // by doing so, the number of splits in the graph is reduced
1934    // expand k last so that rope fusion can write directly into the KV cache
1935 ggml_build_forward_expand(gf, q_cur);
1936 ggml_build_forward_expand(gf, v_cur);
1937 ggml_build_forward_expand(gf, k_cur);
1938
1939 const auto * mctx_cur = inp->mctx;
1940
1941 // store to KV cache
1942 {
1943 const auto & k_idxs = inp->get_k_idxs();
1944 const auto & v_idxs = inp->get_v_idxs();
1945
1946 ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1947 ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1948 }
1949
1950 const auto & kq_mask = inp->get_kq_mask();
1951
1952 ggml_tensor * q = q_cur;
1953 ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1954 ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1955
1956 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1957 cb(cur, "kqv_out", il);
1958
1959 if (wo) {
1960 cur = build_lora_mm(wo, cur);
1961 if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1962 // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1963 ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1964 }
1965 }
1966
1967 if (wo_b) {
1968 cur = ggml_add(ctx0, cur, wo_b);
1969 }
1970
1971 return cur;
1972}
1973
1974static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
1975 ggml_context * ctx0,
1976 const llama_ubatch & ubatch,
1977 const llama_hparams & hparams,
1978 const llama_cparams & cparams,
1979 const llama_kv_cache_context * mctx_cur) {
1980
1981 auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
1982
1983 {
1984 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1985
1986 const auto n_kv = mctx_cur->get_n_kv();
1987 const auto n_tokens = ubatch.n_tokens;
1988 const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1989
1990 inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1991
1992 inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1993 ggml_set_input(inp->self_kq_mask);
1994
1995 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1996 }
1997
1998 return inp;
1999}
2000
2001llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2002 const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2003
2004 auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2005
2006 return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2007}
2008
2009ggml_tensor * llm_graph_context::build_attn(
2010 llm_graph_input_attn_k * inp,
2011 ggml_tensor * wo,
2012 ggml_tensor * wo_b,
2013 ggml_tensor * q_cur,
2014 ggml_tensor * k_cur,
2015 ggml_tensor * v_cur,
2016 ggml_tensor * kq_b,
2017 ggml_tensor * sinks,
2018 ggml_tensor * v_mla,
2019 float kq_scale,
2020 int il) const {
2021 // these nodes are added to the graph together so that they are not reordered
2022 // by doing so, the number of splits in the graph is reduced
2023    // expand k last so that rope fusion can write directly into the KV cache
2024 ggml_build_forward_expand(gf, q_cur);
2025 ggml_build_forward_expand(gf, v_cur);
2026 ggml_build_forward_expand(gf, k_cur);
2027
2028 const auto * mctx_cur = inp->mctx;
2029
2030 // store to KV cache
2031 {
2032 const auto & k_idxs = inp->get_k_idxs();
2033
2034 ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2035 }
2036
2037 const auto & kq_mask = inp->get_kq_mask();
2038
2039 ggml_tensor * q = q_cur;
2040 ggml_tensor * k = mctx_cur->get_k(ctx0, il);
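    // no separate V cache here: V is taken as a view of the K cache, reusing the
    // first v_cur->ne[0] elements of each cached K row (same strides, zero offset)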
2041 ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2042
2043 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2044 cb(cur, "kqv_out", il);
2045
2046 if (wo) {
2047 cur = build_lora_mm(wo, cur);
2048 if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2049 // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2050 ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2051 }
2052 }
2053
2054 if (wo_b) {
2055 cur = ggml_add(ctx0, cur, wo_b);
2056 }
2057
2058 return cur;
2059}
2060
2061ggml_tensor * llm_graph_context::build_attn(
2062 llm_graph_input_attn_kv_iswa * inp,
2063 ggml_tensor * wo,
2064 ggml_tensor * wo_b,
2065 ggml_tensor * q_cur,
2066 ggml_tensor * k_cur,
2067 ggml_tensor * v_cur,
2068 ggml_tensor * kq_b,
2069 ggml_tensor * sinks,
2070 ggml_tensor * v_mla,
2071 float kq_scale,
2072 int il) const {
2073 // these nodes are added to the graph together so that they are not reordered
2074 // by doing so, the number of splits in the graph is reduced
2075 ggml_build_forward_expand(gf, q_cur);
2076
2077 if (k_cur) {
2078 ggml_build_forward_expand(gf, k_cur);
2079 }
2080
2081 if (v_cur) {
2082 ggml_build_forward_expand(gf, v_cur);
2083 }
2084
2085 const auto * mctx_iswa = inp->mctx;
2086
2087 const bool is_swa = hparams.is_swa(il);
2088
2089 const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2090
2091 // optionally store to KV cache
2092 if (k_cur) {
2093 const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2094
2095 ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2096 }
2097
2098 if (v_cur) {
2099 const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2100
2101 ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2102 }
2103
2104 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2105
2106 ggml_tensor * q = q_cur;
2107 ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2108 ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2109
2110 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2111 cb(cur, "kqv_out", il);
2112
2113 if (wo) {
2114 cur = build_lora_mm(wo, cur);
2115 }
2116
2117    if (wo_b) {
2118        //cb(cur, "kqv_wo", il);
2119        cur = ggml_add(ctx0, cur, wo_b);
2120    }
2124
2125 return cur;
2126}
2127
2128llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2129 auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2130
2131 const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2132
2133 inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
2134 ggml_set_input(inp->cross_kq_mask);
2135
2136 inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
2137
2138 return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2139}
2140
2141ggml_tensor * llm_graph_context::build_attn(
2142 llm_graph_input_attn_cross * inp,
2143 ggml_tensor * wo,
2144 ggml_tensor * wo_b,
2145 ggml_tensor * q_cur,
2146 ggml_tensor * k_cur,
2147 ggml_tensor * v_cur,
2148 ggml_tensor * kq_b,
2149 ggml_tensor * sinks,
2150 ggml_tensor * v_mla,
2151 float kq_scale,
2152 int il) const {
2153 // these nodes are added to the graph together so that they are not reordered
2154 // by doing so, the number of splits in the graph is reduced
2155 ggml_build_forward_expand(gf, q_cur);
2156 ggml_build_forward_expand(gf, k_cur);
2157 ggml_build_forward_expand(gf, v_cur);
2158
2159 const auto & kq_mask = inp->get_kq_mask_cross();
2160
2161 ggml_tensor * q = q_cur;
2162 ggml_tensor * k = k_cur;
2163 ggml_tensor * v = v_cur;
2164
2165 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2166 cb(cur, "kqv_out", il);
2167
2168 if (wo) {
2169 cur = build_lora_mm(wo, cur);
2170 }
2171
2172    if (wo_b) {
2173        //cb(cur, "kqv_wo", il);
2174        cur = ggml_add(ctx0, cur, wo_b);
2175    }
2179
2180 return cur;
2181}
2182
2183// TODO: maybe split the inner implementation into its own function,
2184// like the non-sliding-window equivalent, once sliding-window
2185// hybrid caches are a thing.
2186llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2187 const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2188
2189 auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2190
2191 const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
2192
2193 {
2194 const auto n_kv = mctx_cur->get_base()->get_n_kv();
2195
2196 inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2197 inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2198
2199 inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2200 ggml_set_input(inp->self_kq_mask);
2201 ggml_set_name(inp->self_kq_mask, "self_kq_mask");
2202
2203 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2204 ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
2205 }
2206
2207 {
2208 GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2209
2210 const auto n_kv = mctx_cur->get_swa()->get_n_kv();
2211
2212 inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2213 inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2214
2215 inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2216 ggml_set_input(inp->self_kq_mask_swa);
2217 ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
2218
2219 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
2220 ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
2221 }
2222
2223 return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2224}
2225
2226ggml_tensor * llm_graph_context::build_rs(
2227 ggml_tensor * s,
2228 ggml_tensor * state_copy_main,
2229 ggml_tensor * state_copy_extra,
2230 int32_t state_size,
2231 int32_t n_seqs,
2232 uint32_t n_rs,
2233 uint32_t rs_head,
2234 uint32_t rs_size,
2235 int32_t rs_zero,
2236 const llm_graph_get_rows_fn & get_state_rows) const {
2237
2238 ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
2239
2240 // Clear a single state which will then be copied to the other cleared states.
2241 // Note that this is a no-op when the view is zero-sized.
2242 ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2243 ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2244
2245 // copy states
2246 // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2247 // {state_size, rs_size} -> {state_size, n_seqs}
2248 ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2249 ggml_build_forward_expand(gf, output_states);
2250
2251 // copy extra states which won't be changed further (between n_seqs and n_rs)
2252 ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2253 ggml_build_forward_expand(gf,
2254 ggml_cpy(ctx0,
2255 states_extra,
2256 ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
2257
2258 return output_states;
2259}
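
// illustrative example of the copy logic above (hypothetical values): with rs_size = 4,
// rs_head = 1, n_rs = 3, n_seqs = 2 and state_copy = [2, 1, 0], the graph gathers rows
// {2, 1} of `states` as the per-sequence inputs (state_copy_main), while row 0
// (state_copy_extra) is copied back unchanged into slot rs_head + n_seqs = 3 of `s`;
// if rs_zero >= 0, that single row is zeroed first so freshly started sequences begin
// from a cleared state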
2260
2261static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2262 ggml_context * ctx0,
2263 const llama_ubatch & ubatch,
2264 const llama_memory_recurrent_context * mctx_cur) {
2265
2266 auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2267
2268 const int64_t n_rs = mctx_cur->get_n_rs();
2269 const int64_t n_seqs = ubatch.n_seqs;
2270
2271 inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2272 ggml_set_input(inp->s_copy);
2273
2274 inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2275 inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2276
2277 inp->head = mctx_cur->get_head();
2278 inp->rs_z = mctx_cur->get_rs_z();
2279
2280 return inp;
2281}
2282
2283llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2284 const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2285
2286 auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2287
2288 return (llm_graph_input_rs *) res->add_input(std::move(inp));
2289}
2290
2291ggml_tensor * llm_graph_context::build_rs(
2292 llm_graph_input_rs * inp,
2293 ggml_tensor * s,
2294 int32_t state_size,
2295 int32_t n_seqs,
2296 const llm_graph_get_rows_fn & get_state_rows) const {
2297 const auto * kv_state = inp->mctx;
2298
2299 return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2300 kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2301 get_state_rows);
2302}
2303
2304ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2305 llm_graph_input_rs * inp,
2306 const llama_ubatch & ubatch,
2307 int il) const {
2308 const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2309
2310 const auto token_shift_count = hparams.token_shift_count;
2311
2312 const int64_t n_seqs = ubatch.n_seqs;
2313
2314 ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2315
2316 ggml_tensor * token_shift = build_rs(
2317 inp, token_shift_all,
2318 hparams.n_embd_r(), n_seqs);
2319
2320 token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2321
2322 return token_shift;
2323}
2324
2325ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2326 ggml_tensor * token_shift,
2327 const llama_ubatch & ubatch,
2328 int il) const {
2329 const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2330
2331 const auto token_shift_count = hparams.token_shift_count;
2332 const auto n_embd = hparams.n_embd;
2333
2334 const int64_t n_seqs = ubatch.n_seqs;
2335
2336 const auto kv_head = mctx_cur->get_head();
2337
2338 return ggml_cpy(
2339 ctx0,
2340 ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2341 ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2342 );
2343}
2344
2345llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2346 const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2347
2348 auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2349 auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2350
2351 auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2352
2353 return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2354}
2355
2356llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2357 const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2358
2359 auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2360 auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2361
2362 auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2363
2364 return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2365}
2366
2367llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2368 const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2369
2370 auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2371
2372 // build iswa attention input
2373 const auto * attn_ctx = mctx_cur->get_attn();
2374
2375 auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2376
2377 const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
2378
2379 {
2380 const auto n_kv = attn_ctx->get_base()->get_n_kv();
2381
2382 inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2383 inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2384
2385 inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2386 ggml_set_input(inp_attn->self_kq_mask);
2387
2388 inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2389 }
2390
2391 {
2392 const auto n_kv = attn_ctx->get_swa()->get_n_kv();
2393
2394 inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2395 inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2396
2397 inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2398 ggml_set_input(inp_attn->self_kq_mask_swa);
2399
2400 inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2401 }
2402
2403 auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2404
2405 return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2406}
2407
2408void llm_graph_context::build_dense_out(
2409 ggml_tensor * dense_2,
2410 ggml_tensor * dense_3) const {
2411 if (!cparams.embeddings || !(dense_2 || dense_3)) {
2412 return;
2413 }
2414 ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2415 GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2416
2417 if (dense_2) {
2418 cur = ggml_mul_mat(ctx0, dense_2, cur);
2419 }
2420 if (dense_3) {
2421 cur = ggml_mul_mat(ctx0, dense_3, cur);
2422 }
2423 cb(cur, "result_embd_pooled", -1);
2424 res->t_embd_pooled = cur;
2425 ggml_build_forward_expand(gf, cur);
2426}
2427
2428
2429void llm_graph_context::build_pooling(
2430 ggml_tensor * cls,
2431 ggml_tensor * cls_b,
2432 ggml_tensor * cls_out,
2433 ggml_tensor * cls_out_b) const {
2434 if (!cparams.embeddings) {
2435 return;
2436 }
2437
2438 ggml_tensor * inp = res->t_embd;
2439
2440 //// find result_norm tensor for input
2441 //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2442 // inp = ggml_graph_node(gf, i);
2443 // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2444 // break;
2445 // }
2446
2447 // inp = nullptr;
2448 //}
2449
2450 GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2451
2452 ggml_tensor * cur;
2453
2454 switch (pooling_type) {
2455 case LLAMA_POOLING_TYPE_NONE:
2456 {
2457 cur = inp;
2458 } break;
2459 case LLAMA_POOLING_TYPE_MEAN:
2460 {
2461 ggml_tensor * inp_mean = build_inp_mean();
2462 cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
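                    // inp is [n_embd, n_tokens] and inp_mean holds one column per sequence with
                    // weights 1/n_tokens(seq) on that sequence's rows, so this mul_mat yields
                    //   cur[e, s] = sum_t inp[e, t] * inp_mean[t, s]  ==  mean embedding of sequence s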
2463 } break;
2464 case LLAMA_POOLING_TYPE_CLS:
2465 case LLAMA_POOLING_TYPE_LAST:
2466 {
2467 ggml_tensor * inp_cls = build_inp_cls();
2468 cur = ggml_get_rows(ctx0, inp, inp_cls);
2469 } break;
2470 case LLAMA_POOLING_TYPE_RANK:
2471 {
2472 ggml_tensor * inp_cls = build_inp_cls();
2473 cur = ggml_get_rows(ctx0, inp, inp_cls);
2474
2475 // classification head
2476 // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2477 if (cls) {
2478 cur = ggml_mul_mat(ctx0, cls, cur);
2479 if (cls_b) {
2480 cur = ggml_add(ctx0, cur, cls_b);
2481 }
2482 cur = ggml_tanh(ctx0, cur);
2483 }
2484
2485 // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
2486 // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2487 // Single layer classification head (direct projection)
2488 // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
2489 if (cls_out) {
2490 cur = ggml_mul_mat(ctx0, cls_out, cur);
2491 if (cls_out_b) {
2492 cur = ggml_add(ctx0, cur, cls_out_b);
2493 }
2494 }
2495
2496 // softmax for qwen3 reranker
2497 if (arch == LLM_ARCH_QWEN3) {
2498 cur = ggml_soft_max(ctx0, cur);
2499 }
2500 } break;
2501 default:
2502 {
2503 GGML_ABORT("unknown pooling type");
2504 }
2505 }
2506
2507 cb(cur, "result_embd_pooled", -1);
2508 res->t_embd_pooled = cur;
2509
2510 ggml_build_forward_expand(gf, cur);
2511}
2512
2513void llm_graph_context::build_sampling() const {
2514 if (samplers.empty() || !res->t_logits) {
2515 return;
2516 }
2517
2518 std::array<ggml_tensor *, 2> outs;
2519 outs[0] = res->t_logits;
2520
2521 auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2522 res->add_input(std::move(inp_sampling));
2523
2524 std::map<llama_seq_id, int32_t> seq_to_logit_row;
2525 int32_t logit_row_idx = 0;
2526
2527 for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2528 if (ubatch.output[i]) {
2529 llama_seq_id seq_id = ubatch.seq_id[i][0];
2530 seq_to_logit_row[seq_id] = logit_row_idx;
2531 logit_row_idx++;
2532 }
2533 }
2534
2535    // res->t_logits contains the logits for every token for which output was requested (logits=1 or output=1)
2536 GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2537
2538 // add a dummy row of logits
2539 // this trick makes the graph static, regardless of which samplers are activated
2540 // this is important in order to minimize graph reallocations
2541 ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
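    // ggml_pad(..., 0, 1, 0, 0) appends one zero-filled row along dimension 1, so logits_t is
    // [n_vocab, n_outputs + 1]; the row-0 view taken below therefore always exists, even for a
    // ubatch that produced no output rows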
2542
2543 for (const auto & [seq_id, sampler] : samplers) {
2544 const auto it = seq_to_logit_row.find(seq_id);
2545
2546 // inactive samplers always work on the first row
2547 const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2548 const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
2549
2550 ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2551 ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2552
2553 struct llama_sampler_data data = {
2554 /*.logits =*/ logits_seq,
2555 /*.probs =*/ nullptr,
2556 /*.sampled =*/ nullptr,
2557 /*.candidates =*/ nullptr,
2558 };
2559
2560 assert(sampler->iface->backend_apply);
2561 sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2562
2563 if (data.sampled != nullptr) {
2564 res->t_sampled[seq_id] = data.sampled;
2565 outs[1] = data.sampled;
2566 ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2567 }
2568
2569 if (data.probs != nullptr) {
2570 res->t_sampled_probs[seq_id] = data.probs;
2571 outs[1] = data.probs;
2572 ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2573 }
2574
2575 if (data.logits != nullptr) {
2576 res->t_sampled_logits[seq_id] = data.logits;
2577 outs[1] = data.logits;
2578 ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2579 }
2580
2581 if (data.candidates != nullptr) {
2582 res->t_candidates[seq_id] = data.candidates;
2583 outs[1] = data.candidates;
2584 ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2585 }
2586 }
2587
2588 // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2589 /*
2590 for (const auto & [seq_id, sampler] : samplers) {
2591 if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2592 ggml_tensor * selected_token = it->second;
2593 if (selected_token != nullptr) {
2594 llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2595 }
2596 }
2597 }
2598 */
2599}
2600
2601int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
2602 // TODO move to hparams if a T5 variant appears that uses a different value
2603 const int64_t max_distance = 128;
2604
2605 if (bidirectional) {
2606 n_buckets >>= 1;
2607 }
2608
2609 const int64_t max_exact = n_buckets >> 1;
2610
2611 int32_t relative_position = x - y;
2612 int32_t relative_bucket = 0;
2613
2614 if (bidirectional) {
2615 relative_bucket += (relative_position > 0) * n_buckets;
2616 relative_position = std::abs(relative_position);
2617 } else {
2618 relative_position = -std::min<int32_t>(relative_position, 0);
2619 }
2620
2621 int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
2622 relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
2623 relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
2624
2625 return relative_bucket;
2626}
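
// worked example (hypothetical values): with n_buckets = 32, bidirectional = false and the
// fixed max_distance = 128, max_exact = 16; for x = 5, y = 20 the distance is 15 < 16, so the
// bucket is simply 15; for x = 5, y = 200 the distance is 195, and the logarithmic branch gives
// floor(16 + log(195/16) * 16 / log(128/16)) = floor(35.2) = 35, clamped to n_buckets - 1 = 31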