1#pragma once
2
3#include "llama-arch.h"
4#include "llama-batch.h"
5#include "llama-hparams.h"
6#include "llama-adapter.h"
7
#include <cstdint>
#include <cstdlib> // getenv, atoi
9#include <vector>
10#include <memory>
11#include <set>
12#include <functional>
13#include <map>
14
15struct ggml_cgraph;
16struct ggml_context;
17struct ggml_tensor;
18
19struct llama_cparams;
20
21struct llama_memory_context_i;
22
23class llama_kv_cache_context;
24class llama_kv_cache_iswa_context;
25class llama_memory_recurrent_context;
26class llama_memory_hybrid_context;
27class llama_memory_hybrid_iswa_context;
28
29// certain models (typically multi-modal) can produce different types of graphs
30enum llm_graph_type {
31 LLM_GRAPH_TYPE_DEFAULT,
32 LLM_GRAPH_TYPE_ENCODER,
33 LLM_GRAPH_TYPE_DECODER,
34};
35
36enum llm_ffn_op_type {
37 LLM_FFN_SILU,
38 LLM_FFN_GELU,
39 LLM_FFN_RELU,
40 LLM_FFN_RELU_SQR,
41 LLM_FFN_SWIGLU,
42 LLM_FFN_GEGLU,
43 LLM_FFN_REGLU,
44 LLM_FFN_SWIGLU_OAI_MOE,
45};
46
47enum llm_ffn_gate_type {
48 LLM_FFN_SEQ,
49 LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
50};
51
52enum llm_norm_type {
53 LLM_NORM,
54 LLM_NORM_RMS,
55 LLM_NORM_GROUP,
56};
57
58// TODO: tmp - need something better to pass the data from the encoder to the decoder
59struct llama_cross {
60 // the output embeddings from the encoder as a ggml tensor
61 // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
62 // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
63 //ggml_tensor * t_embd = nullptr;
64
65 int64_t n_embd = 0;
66 int64_t n_enc = 0;
67
68 // embeddings data copied to host memory (tmp)
69 std::vector<float> v_embd;
70
71 // needed to construct the cross-attention mask in the decoder
72 std::vector<std::set<llama_seq_id>> seq_ids_enc;
73};
74
75struct llm_graph_params;
76
77//
78// llm_graph_input
79//
80
81class llm_graph_input_i {
82public:
83 llm_graph_input_i() {
84 const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
85 debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
86 }
87
88 virtual ~llm_graph_input_i() = default;
89
90 virtual void set_input(const llama_ubatch * ubatch) = 0;
91
    // return true if the input tensors produced with the provided graph parameters would be
    // the same as the input tensors currently stored in this object
94 virtual bool can_reuse(const llm_graph_params & params) {
        // returning false by default prevents reusing the graph when the check for
        // this input type has not been implemented yet
97 GGML_UNUSED(params);
98 return false;
99 }
100protected:
101 // env: LLAMA_GRAPH_INPUT_DEBUG
102 int debug = 0;
103};
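
// a minimal sketch of what a derived input's can_reuse() typically verifies (illustrative only;
// the actual overrides live in the implementation and may check additional conditions):
//
//   bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
//       // reuse only if the new ubatch would produce a position tensor of the same shape
//       return pos && pos->ne[0] == (int64_t) params.ubatch.n_tokens*n_pos_per_embd;
//   }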
104
105using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
106
107class llm_graph_input_embd : public llm_graph_input_i {
108public:
109 llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
110 virtual ~llm_graph_input_embd() = default;
111
112 void set_input(const llama_ubatch * ubatch) override;
113
114 bool can_reuse(const llm_graph_params & params) override;
115
116 ggml_tensor * tokens = nullptr; // I32 [n_batch]
117 ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
118
119 const int64_t n_embd = 0;
120};
121
122class llm_graph_input_pos : public llm_graph_input_i {
123public:
124 llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
125 virtual ~llm_graph_input_pos() = default;
126
127 void set_input(const llama_ubatch * ubatch) override;
128
129 bool can_reuse(const llm_graph_params & params) override;
130
131 ggml_tensor * pos = nullptr; // I32 [n_batch]
132
133 const uint32_t n_pos_per_embd = 1;
134};
135
136// temperature tuning, used by llama4
137class llm_graph_input_attn_temp : public llm_graph_input_i {
138public:
139 llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
140 : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
141 virtual ~llm_graph_input_attn_temp() = default;
142
143 void set_input(const llama_ubatch * ubatch) override;
144
145 ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
146
147 const uint32_t n_attn_temp_floor_scale;
148 const float f_attn_temp_scale;
149 const float f_attn_temp_offset;
150};
151
152class llm_graph_input_pos_bucket : public llm_graph_input_i {
153public:
154 llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
155 virtual ~llm_graph_input_pos_bucket() = default;
156
157 void set_input(const llama_ubatch * ubatch) override;
158
159 ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
160
161 const llama_hparams hparams;
162};
163
164class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
165public:
166 llm_graph_input_pos_bucket_kv(
167 const llama_hparams & hparams,
168 const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
169 virtual ~llm_graph_input_pos_bucket_kv() = default;
170
171 void set_input(const llama_ubatch * ubatch) override;
172
173 ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
174
175 const llama_hparams hparams;
176
177 const llama_kv_cache_context * mctx;
178};
179
180class llm_graph_input_out_ids : public llm_graph_input_i {
181public:
182 llm_graph_input_out_ids(
183 const llama_hparams & hparams,
184 const llama_cparams & cparams,
185 uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
186 virtual ~llm_graph_input_out_ids() = default;
187
188 void set_input(const llama_ubatch * ubatch) override;
189
190 bool can_reuse(const llm_graph_params & params) override;
191
192 ggml_tensor * out_ids; // I32 [n_outputs]
193
194 const llama_hparams hparams;
195 const llama_cparams cparams;
196
197 const uint32_t n_outputs;
198};
199
200class llm_graph_input_mean : public llm_graph_input_i {
201public:
202 llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
203 virtual ~llm_graph_input_mean() = default;
204
205 void set_input(const llama_ubatch * ubatch) override;
206
207 ggml_tensor * mean; // F32 [n_batch, n_batch]
208
209 const llama_cparams cparams;
210};
211
212class llm_graph_input_cls : public llm_graph_input_i {
213public:
214 llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
215 virtual ~llm_graph_input_cls() = default;
216
217 void set_input(const llama_ubatch * ubatch) override;
218
219 ggml_tensor * cls; // I32 [n_batch]
220
221 const llama_cparams cparams;
222 const llm_arch arch;
223};
224
225class llm_graph_input_rs : public llm_graph_input_i {
226public:
227 llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
228 virtual ~llm_graph_input_rs() = default;
229
230 void set_input(const llama_ubatch * ubatch) override;
231
232 bool can_reuse(const llm_graph_params & params) override;
233
234 ggml_tensor * s_copy; // I32 [n_rs]
235
236 // views of s_copy, computed once per graph
237 // and shared across layers which use build_rs
238 ggml_tensor * s_copy_main; // I32 [n_seqs]
239 ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
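
    // conceptually, the two views above partition s_copy (illustrative sketch; the actual views
    // are created during graph build, roughly):
    //
    //   s_copy_main  = ggml_view_1d(ctx, s_copy, n_seqs, 0);
    //   s_copy_extra = ggml_view_1d(ctx, s_copy, n_rs - n_seqs, n_seqs*ggml_element_size(s_copy));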
240
241 const llama_memory_recurrent_context * mctx;
242
243 // used in view offsets, need to match for valid graph reuse
244 uint32_t head;
245 int32_t rs_z;
246};
247
248class llm_graph_input_cross_embd : public llm_graph_input_i {
249public:
250 llm_graph_input_cross_embd(
251 const llama_cross * cross) : cross(cross) {}
252 virtual ~llm_graph_input_cross_embd() = default;
253
254 void set_input(const llama_ubatch * ubatch) override;
255
256 ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
257
258 const llama_cross * cross;
259};
260
261class llm_graph_input_attn_no_cache : public llm_graph_input_i {
262public:
263 llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
264 hparams(hparams),
265 cparams(cparams) {
266 }
267 ~llm_graph_input_attn_no_cache() = default;
268
269 void set_input(const llama_ubatch * ubatch) override;
270
271 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
272 ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
273
274 // n_tokens == n_batch
275 ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
276 ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
277 ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
278 ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
279
280 const llama_hparams hparams;
281 const llama_cparams cparams;
282};
283
284class llm_graph_input_attn_kv : public llm_graph_input_i {
285public:
286 llm_graph_input_attn_kv(
287 const llama_hparams & hparams,
288 const llama_cparams & cparams,
289 const llama_kv_cache_context * mctx) :
290 hparams(hparams),
291 cparams(cparams),
292 mctx(mctx) {
293 }
294 ~llm_graph_input_attn_kv() = default;
295
296 void set_input(const llama_ubatch * ubatch) override;
297
298 bool can_reuse(const llm_graph_params & params) override;
299
300 ggml_tensor * get_k_idxs() const { return self_k_idxs; }
301 ggml_tensor * get_v_idxs() const { return self_v_idxs; }
302
303 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
304
305 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
306 ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
307
308 ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
309 ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
310
311 // note: these have to be copies because in order to be able to reuse a graph, its inputs
312 // need to carry these parameters with them. otherwise, they can point to freed
313 // llm_graph_params from a previous batch, causing stack-use-after-return
314 const llama_hparams hparams;
315 const llama_cparams cparams;
316
317 const llama_kv_cache_context * mctx;
318};
319
320// V-less input for the KV cache
321// ref: https://github.com/ggml-org/llama.cpp/pull/19067
322class llm_graph_input_attn_k : public llm_graph_input_i {
323public:
324 llm_graph_input_attn_k(
325 const llama_hparams & hparams,
326 const llama_cparams & cparams,
327 const llama_kv_cache_context * mctx) :
328 hparams(hparams),
329 cparams(cparams),
330 mctx(mctx) {
331 }
332 ~llm_graph_input_attn_k() = default;
333
334 void set_input(const llama_ubatch * ubatch) override;
335
336 bool can_reuse(const llm_graph_params & params) override;
337
338 ggml_tensor * get_k_idxs() const { return self_k_idxs; }
339
340 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
341
342 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
343
344 ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
345 ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
346
347 const llama_hparams hparams;
348 const llama_cparams cparams;
349
350 const llama_kv_cache_context * mctx;
351};
352
353class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
354public:
355 llm_graph_input_attn_kv_iswa(
356 const llama_hparams & hparams,
357 const llama_cparams & cparams,
358 const llama_kv_cache_iswa_context * mctx) :
359 hparams(hparams),
360 cparams(cparams),
361 mctx(mctx) {
362 }
363 ~llm_graph_input_attn_kv_iswa() = default;
364
365 void set_input(const llama_ubatch * ubatch) override;
366
367 bool can_reuse(const llm_graph_params & params) override;
368
369 ggml_tensor * get_k_idxs() const { return self_k_idxs; }
370 ggml_tensor * get_v_idxs() const { return self_v_idxs; }
371 ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
372 ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
373
374 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
375 ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
376
377 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
378 ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
379 ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
380 ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
381
382 ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
383 ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
384 ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
385 ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
386
387 const llama_hparams hparams;
388 const llama_cparams cparams;
389
390 const llama_kv_cache_iswa_context * mctx;
391};
392
393class llm_graph_input_attn_cross : public llm_graph_input_i {
394public:
395 llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
396 ~llm_graph_input_attn_cross() = default;
397
398 void set_input(const llama_ubatch * ubatch) override;
399
400 ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
401
402 ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
403 ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
404
405 const llama_cross * cross = nullptr;
406};
407
408class llm_graph_input_mem_hybrid : public llm_graph_input_i {
409public:
410 llm_graph_input_mem_hybrid(
411 const llama_cparams & cparams,
412 std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
413 std::unique_ptr<llm_graph_input_rs> inp_rs,
414 const llama_memory_hybrid_context * mctx) :
415 inp_attn(std::move(inp_attn)),
416 inp_rs(std::move(inp_rs)),
417 cparams(cparams),
418 mctx(mctx) { }
419 virtual ~llm_graph_input_mem_hybrid() = default;
420
421 void set_input(const llama_ubatch * ubatch) override;
422
423 bool can_reuse(const llm_graph_params & params) override;
424
425 std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
426 std::unique_ptr<llm_graph_input_rs> inp_rs;
427
428 llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
429 llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
430
431 const llama_cparams cparams;
432
433 const llama_memory_hybrid_context * mctx;
434};
435
436class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
437public:
438 llm_graph_input_mem_hybrid_k(
439 const llama_cparams & cparams,
440 std::unique_ptr<llm_graph_input_attn_k> inp_attn,
441 std::unique_ptr<llm_graph_input_rs> inp_rs,
442 const llama_memory_hybrid_context * mctx) :
443 inp_attn(std::move(inp_attn)),
444 inp_rs(std::move(inp_rs)),
445 cparams(cparams),
446 mctx(mctx) { }
447 virtual ~llm_graph_input_mem_hybrid_k() = default;
448
449 void set_input(const llama_ubatch * ubatch) override;
450
451 bool can_reuse(const llm_graph_params & params) override;
452
453 std::unique_ptr<llm_graph_input_attn_k> inp_attn;
454 std::unique_ptr<llm_graph_input_rs> inp_rs;
455
456 llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
457 llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
458
459 const llama_cparams cparams;
460
461 const llama_memory_hybrid_context * mctx;
462};
463
464class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
465public:
466 llm_graph_input_mem_hybrid_iswa(
467 const llama_cparams & cparams,
468 std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
469 std::unique_ptr<llm_graph_input_rs> inp_rs,
470 const llama_memory_hybrid_iswa_context * mctx) :
471 inp_attn(std::move(inp_attn)),
472 inp_rs(std::move(inp_rs)),
473 cparams(cparams),
474 mctx(mctx) { }
475 virtual ~llm_graph_input_mem_hybrid_iswa() = default;
476
477 void set_input(const llama_ubatch * ubatch) override;
478
479 bool can_reuse(const llm_graph_params & params) override;
480
481 std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
482 std::unique_ptr<llm_graph_input_rs> inp_rs;
483
484 llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
485 llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
486
487 const llama_cparams cparams;
488
489 const llama_memory_hybrid_iswa_context * mctx;
490};
491
492class llm_graph_input_sampling : public llm_graph_input_i {
493public:
494 llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
495 samplers(std::move(samplers)) { }
496 virtual ~llm_graph_input_sampling() = default;
497
498 void set_input(const llama_ubatch * ubatch) override;
499 bool can_reuse(const llm_graph_params & params) override;
500
501 std::map<llama_seq_id, llama_sampler *> samplers;
502};
503
504//
505// llm_graph_result
506//
507
508// these objects deliver the result from the graph build process back to the llama_context
509// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
513
514// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
515using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
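
// a minimal sketch of such a callback (illustrative; the callback actually used is provided by the
// llama_context when it fills in llm_graph_params::cb):
//
//   llm_graph_cb cb = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
//       GGML_UNUSED(ubatch);
//       if (il >= 0) {
//           ggml_format_name(cur, "%s-%d", name, il); // name the tensor per layer
//       } else {
//           ggml_set_name(cur, name);
//       }
//   };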
516
517class llm_graph_result;
518
519struct llm_graph_params {
520 llm_arch arch = LLM_ARCH_UNKNOWN;
521
522 llama_hparams hparams;
523 llama_cparams cparams;
524
525 llama_ubatch ubatch; // note: intentionally make a copy
526
527 llm_graph_type gtype;
528
529 ggml_backend_sched_t sched;
530 ggml_backend_t backend_cpu;
531
532 const llama_adapter_cvec * cvec;
533 const llama_adapter_loras * loras;
534 const llama_memory_context_i * mctx;
535 const llama_cross * cross;
536
537 std::map<llama_seq_id, llama_sampler *> samplers;
538
539 static bool samplers_equal(
540 const std::map<llama_seq_id, llama_sampler *> & lhs,
541 const std::map<llama_seq_id, llama_sampler *> & rhs) {
542 if (lhs.size() != rhs.size()) {
543 return false;
544 }
545 for (const auto & [seq_id, sampler] : lhs) {
546 auto it = rhs.find(seq_id);
547 if (it == rhs.end() || it->second != sampler) {
548 return false;
549 }
550 }
551 return true;
552 }
553
554 uint32_t n_outputs;
555
556 llm_graph_cb cb;
557
558 llm_graph_result * res;
559
    // return true if the "other" params would result in a graph with the same topology as the current params
    // having the same topology allows us to reuse the graph in some cases
562 bool allow_reuse(const llm_graph_params & other) const {
563 // first check the ubatch
564 bool can_reuse_ubatch =
565 ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
566 ubatch.n_tokens == other.ubatch.n_tokens &&
567 ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
568 ubatch.n_seqs == other.ubatch.n_seqs &&
569 ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
570 (
571 (!ubatch.token && !other.ubatch.token) ||
572 (!ubatch.embd && !other.ubatch.embd)
573 );
574
        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
        // because the set of attention streams would be different for different sequences
577 if (can_reuse_ubatch && ubatch.equal_seqs()) {
578 if (!ubatch.data) {
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
                // therefore we cannot perform the sequence id check. this normally should never happen
581 can_reuse_ubatch = false;
582 } else {
583 for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
584 can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
585 }
586 }
587 }
588
589 if (!can_reuse_ubatch) {
590 return false;
591 }
592
593 if (n_outputs != other.n_outputs) {
594 return false;
595 }
596
597 if (!samplers_equal(samplers, other.samplers)) {
598 return false;
599 }
600
601 if (samplers.size() > 0) {
602 if (!ubatch.data || !other.ubatch.data) {
603 return false;
604 }
605
606 // check that the outputs are the same for all samplers
607 for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
608 if (ubatch.output[i] != other.ubatch.output[i] ||
609 ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
610 return false;
611 }
612 }
613 }
614
615 return
616 cparams.embeddings == other.cparams.embeddings &&
617 cparams.causal_attn == other.cparams.causal_attn &&
618 arch == other.arch &&
619 gtype == other.gtype &&
620 cvec == other.cvec &&
621 loras == other.loras &&
622 cross == other.cross;
623 }
624};
625
626class llm_graph_result {
627public:
628 llm_graph_result(int64_t max_nodes);
629
630 virtual ~llm_graph_result() = default;
631
632 ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
633 ggml_tensor * get_logits() const { return t_logits; }
634 ggml_tensor * get_embd() const { return t_embd; }
635 ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
636
637 ggml_cgraph * get_gf() const { return gf; }
638 ggml_context * get_ctx() const { return ctx_compute.get(); }
639
640 int64_t get_max_nodes() const;
641
642 void reset();
643
644 void set_inputs(const llama_ubatch * ubatch);
645 void set_outputs();
646
647 // try to update the existing graph result using the new graph parameters in order to reuse it
648 // this can only be done if we determine that the resulting graph using the new graph parameters
649 // would be identical to the existing graph. in that case, we simply have to update the memory
650 // contexts of the input tensors of the graph and we can reuse it for another computation
651 // return true if the graph was updated and can be reused
652 bool can_reuse(const llm_graph_params & params);
653
654 llm_graph_input_i * add_input(llm_graph_input_ptr input);
655
656 void set_params(const llm_graph_params & params);
657
658 // important graph nodes
659 ggml_tensor * t_inp_tokens = nullptr;
660 ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
661 ggml_tensor * t_logits = nullptr;
662 ggml_tensor * t_embd = nullptr;
663 ggml_tensor * t_embd_pooled = nullptr;
664
665 std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
666 std::map<llama_seq_id, ggml_tensor*> t_candidates;
667 std::map<llama_seq_id, ggml_tensor*> t_sampled;
668 std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
669
670 std::vector<llm_graph_input_ptr> inputs;
671
672 ggml_context_ptr ctx_compute;
673
674 // memory buffers used to evaluate the model
675 std::vector<uint8_t> buf_compute_meta;
676
677 ggml_cgraph * gf;
678
679 int64_t max_nodes;
680
681private:
682 // keep a copy of the previous graph parameters
683 // we will use this to determine whether the graph can be reused by comparing them with the new parameters
684 // note: these are updated after constructing the new graph
685 llm_graph_params params;
686
687 // env: LLAMA_GRAPH_RESULT_DEBUG
688 int debug = 0;
689};
690
691using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
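
// a rough sketch of how the reuse mechanism is meant to be driven by the caller (illustrative only;
// the actual flow lives in the llama_context implementation):
//
//   llm_graph_params params = /* fill in arch, ubatch, mctx, n_outputs, cb, res, ... */;
//   if (res->can_reuse(params)) {
//       // same topology as the previous graph - just update the inputs and evaluate it again
//   } else {
//       res->reset();
//       // ... build a new graph via llm_graph_context ...
//   }
//   res->set_inputs(&ubatch);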
692
693//
694// llm_graph_context
695//
696
697// used in build_rs to properly order writes and avoid unnecessary copies
698using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
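
// for example, a caller could pass a custom get-rows function to post-process the gathered states
// (illustrative sketch; the default is plain ggml_get_rows):
//
//   llm_graph_get_rows_fn my_get_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
//       ggml_tensor * rows = ggml_get_rows(ctx, states, ids);
//       return ggml_cast(ctx, rows, GGML_TYPE_F32); // e.g. ensure the gathered states are F32
//   };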
699
700struct llm_graph_context {
701 const llm_arch arch;
702
703 const llama_hparams & hparams;
704 const llama_cparams & cparams;
705 const llama_ubatch & ubatch;
706
707 const int64_t n_embd;
708 const int64_t n_layer;
709 const int64_t n_rot;
710 const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
711 const int64_t n_head;
712 const int64_t n_head_kv;
713 const int64_t n_embd_head_k;
714 const int64_t n_embd_k_gqa;
715 const int64_t n_embd_head_v;
716 const int64_t n_embd_v_gqa;
717 const int64_t n_expert;
718 const int64_t n_expert_used;
719
720 const float freq_base;
721 const float freq_scale;
722 const float ext_factor;
723 const float attn_factor;
724 const float beta_fast;
725 const float beta_slow;
726 const float norm_eps;
727 const float norm_rms_eps;
728
729 const int64_t n_tokens;
730 const int64_t n_outputs;
731 const int32_t n_ctx_orig; // yarn
732
733 const enum llama_pooling_type pooling_type;
734 const enum llama_rope_type rope_type;
735
736 ggml_backend_sched_t sched;
737
738 ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
739
740 const llama_adapter_cvec * cvec;
741 const llama_adapter_loras * loras;
742 const llama_memory_context_i * mctx;
743 const llama_cross * cross;
744
745 std::map<llama_seq_id, llama_sampler *> samplers;
746
747 const llm_graph_cb & cb_func;
748
749 llm_graph_result * res;
750
751 ggml_context * ctx0 = nullptr;
752 ggml_cgraph * gf = nullptr;
753
754 llm_graph_context(const llm_graph_params & params);
755 virtual ~llm_graph_context() = default;
756
757 void cb(ggml_tensor * cur, const char * name, int il) const;
758
759 //
760 // common
761 //
762
763 ggml_tensor * build_cvec(
764 ggml_tensor * cur,
765 int il) const;
766
    // do mat_mul, optionally applying lora
768 ggml_tensor * build_lora_mm(
769 ggml_tensor * w,
770 ggml_tensor * cur) const;
771
    // do mat_mul_id, optionally applying lora
773 ggml_tensor * build_lora_mm_id(
774 ggml_tensor * w, // ggml_tensor * as
775 ggml_tensor * cur, // ggml_tensor * b
776 ggml_tensor * ids) const;
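
    // usage sketch (illustrative; `layer.wq` stands for an arbitrary model weight):
    //
    //   ggml_tensor * q = build_lora_mm(layer.wq, cur);
    //   // behaves like ggml_mul_mat(ctx0, layer.wq, cur), plus the contribution of any active LoRA adapters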
777
778 ggml_tensor * build_norm(
779 ggml_tensor * cur,
780 ggml_tensor * mw,
781 ggml_tensor * mb,
782 llm_norm_type type,
783 int il) const;
784
785 ggml_tensor * build_ffn(
786 ggml_tensor * cur,
787 ggml_tensor * up,
788 ggml_tensor * up_b,
789 ggml_tensor * up_s,
790 ggml_tensor * gate,
791 ggml_tensor * gate_b,
792 ggml_tensor * gate_s,
793 ggml_tensor * down,
794 ggml_tensor * down_b,
795 ggml_tensor * down_s,
796 ggml_tensor * act_scales,
797 llm_ffn_op_type type_op,
798 llm_ffn_gate_type type_gate,
799 int il) const;
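
    // typical usage for a SwiGLU-style FFN with a parallel gate (illustrative sketch; the tensor
    // names are hypothetical model tensors, and unused biases/scales are passed as nullptr):
    //
    //   cur = build_ffn(cur,
    //           layer.ffn_up,   nullptr, nullptr,
    //           layer.ffn_gate, nullptr, nullptr,
    //           layer.ffn_down, nullptr, nullptr,
    //           nullptr,
    //           LLM_FFN_SILU, LLM_FFN_PAR, il);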
800
801 // build MoE FFN without bias tensors
802 ggml_tensor * build_moe_ffn(
803 ggml_tensor * cur,
804 ggml_tensor * gate_inp,
805 ggml_tensor * up_exps,
806 ggml_tensor * gate_exps,
807 ggml_tensor * down_exps,
808 ggml_tensor * exp_probs_b,
809 int64_t n_expert,
810 int64_t n_expert_used,
811 llm_ffn_op_type type_op,
812 bool norm_w,
813 bool scale_w,
814 float w_scale,
815 llama_expert_gating_func_type gating_op,
816 int il,
817 ggml_tensor * probs_in = nullptr) const;
818
819 ggml_tensor * build_moe_ffn(
820 ggml_tensor * cur,
821 ggml_tensor * gate_inp,
822 ggml_tensor * gate_inp_b,
823 ggml_tensor * up_exps,
824 ggml_tensor * up_exps_b,
825 ggml_tensor * gate_exps,
826 ggml_tensor * gate_exps_b,
827 ggml_tensor * down_exps,
828 ggml_tensor * down_exps_b,
829 ggml_tensor * exp_probs_b,
830 int64_t n_expert,
831 int64_t n_expert_used,
832 llm_ffn_op_type type_op,
833 bool norm_w,
834 bool scale_w,
835 float w_scale,
836 llama_expert_gating_func_type gating_op,
837 int il,
838 ggml_tensor * probs_in = nullptr) const;
839
840 //
841 // inputs
842 //
843
844 ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
845 ggml_tensor * build_inp_pos() const;
846 ggml_tensor * build_inp_attn_scale() const;
847 ggml_tensor * build_inp_out_ids() const;
848 ggml_tensor * build_inp_mean() const;
849 ggml_tensor * build_inp_cls() const;
850
851 ggml_tensor * build_inp_cross_embd() const;
852 ggml_tensor * build_inp_pos_bucket_enc() const;
853 ggml_tensor * build_inp_pos_bucket_dec() const;
854 ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
855
856 //
857 // attention
858 //
859
860 ggml_tensor * build_attn_mha(
861 ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
862 ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
863 ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
864 ggml_tensor * kq_b,
865 ggml_tensor * kq_mask,
866 ggml_tensor * sinks, // [n_head_q]
867 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
868 float kq_scale,
869 int il) const;
870
871 llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
872
873 ggml_tensor * build_attn(
874 llm_graph_input_attn_no_cache * inp,
875 ggml_tensor * wo,
876 ggml_tensor * wo_b,
877 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
878 ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
879 ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
880 ggml_tensor * kq_b,
881 ggml_tensor * sinks, // [n_head_q]
882 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
883 float kq_scale,
884 int il) const;
885
886 llm_graph_input_attn_kv * build_attn_inp_kv() const;
887
888 ggml_tensor * build_attn(
889 llm_graph_input_attn_kv * inp,
890 ggml_tensor * wo,
891 ggml_tensor * wo_b,
892 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
893 ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
894 ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
895 ggml_tensor * kq_b,
896 ggml_tensor * sinks, // [n_head_q]
897 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
898 float kq_scale,
899 int il) const;
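
    // typical per-layer usage with the unified KV cache (illustrative sketch; `layer.*` and the
    // Q/K/V tensors stand for hypothetical per-layer tensors built earlier in the graph):
    //
    //   auto * inp_attn = build_attn_inp_kv(); // created once per graph
    //   ...
    //   cur = build_attn(inp_attn, layer.wo, layer.bo,
    //           Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);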
900
901 llm_graph_input_attn_k * build_attn_inp_k() const;
902
903 ggml_tensor * build_attn(
904 llm_graph_input_attn_k * inp,
905 ggml_tensor * wo,
906 ggml_tensor * wo_b,
907 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
908 ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
909 ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
910 ggml_tensor * kq_b,
911 ggml_tensor * sinks, // [n_head_q]
912 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
913 float kq_scale,
914 int il) const;
915
916 llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
917
918 // note: if k_cur or v_cur are not provided, they will not be stored in the memory
919 ggml_tensor * build_attn(
920 llm_graph_input_attn_kv_iswa * inp,
921 ggml_tensor * wo,
922 ggml_tensor * wo_b,
923 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
924 ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
925 ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
926 ggml_tensor * kq_b,
927 ggml_tensor * sinks, // [n_head_q]
928 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
929 float kq_scale,
930 int il) const;
931
932 llm_graph_input_attn_cross * build_attn_inp_cross() const;
933
934 ggml_tensor * build_attn(
935 llm_graph_input_attn_cross * inp,
936 ggml_tensor * wo,
937 ggml_tensor * wo_b,
938 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
939 ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
940 ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
941 ggml_tensor * kq_b,
942 ggml_tensor * sinks, // [n_head_q]
943 ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
944 float kq_scale,
945 int il) const;
946
947 //
948 // recurrent
949 //
950
    // TODO: move this implementation to llama_memory_recurrent.
    // this is analogous to llama_kv_cache::cpy_k / cpy_v
    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. this would likely require splitting the
    // implementation into 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    // `llama_memory_recurrent`
956 ggml_tensor * build_rs(
957 ggml_tensor * s,
958 ggml_tensor * state_copy_main,
959 ggml_tensor * state_copy_extra,
960 int32_t state_size,
961 int32_t n_seqs,
962 uint32_t n_rs,
963 uint32_t rs_head,
964 uint32_t rs_size,
965 int32_t rs_zero,
966 const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
967
968 llm_graph_input_rs * build_rs_inp() const;
969
970 ggml_tensor * build_rs(
971 llm_graph_input_rs * inp,
972 ggml_tensor * s,
973 int32_t state_size,
974 int32_t n_seqs,
975 const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
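
    // typical usage in a recurrent block (illustrative sketch; `s_all` stands for a hypothetical
    // per-layer state tensor owned by the recurrent memory):
    //
    //   auto * inp_rs = build_rs_inp(); // created once per graph
    //   ...
    //   ggml_tensor * s = build_rs(inp_rs, s_all, state_size, n_seqs); // gather the states used by this ubatch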
976
977 ggml_tensor * build_rwkv_token_shift_load(
978 llm_graph_input_rs * inp,
979 const llama_ubatch & ubatch,
980 int il) const;
981
982 ggml_tensor * build_rwkv_token_shift_store(
983 ggml_tensor * token_shift,
984 const llama_ubatch & ubatch,
985 int il) const;

    //
987 // hybrid
988 //
989
990 llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
991 llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;
992
993 llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
994
995 //
996 // pooling
997 //
998
999 void build_pooling(
1000 ggml_tensor * cls,
1001 ggml_tensor * cls_b,
1002 ggml_tensor * cls_out,
1003 ggml_tensor * cls_out_b) const;
1004
1005 //
1006 // sampling (backend sampling)
1007 //
1008
1009 void build_sampling() const;
1010
1011 //
1012 // dense (out)
1013 //
1014
1015 void build_dense_out(
1016 ggml_tensor * dense_2,
1017 ggml_tensor * dense_3) const;
1018};
1019
1020// TODO: better name
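// maps the relative position between two tokens to a bucket index, T5-style: nearby positions get
// dedicated buckets while more distant positions share logarithmically spaced buckets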
1021int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);