#pragma once

#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"

#include <cstdint>
#include <cstdlib> // getenv, atoi (used by llm_graph_input_i below)
#include <vector>
#include <memory>
#include <set>
#include <functional>
#include <map>

struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;

struct llama_cparams;

struct llama_memory_context_i;

class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
class llama_memory_hybrid_iswa_context;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
    LLM_GRAPH_TYPE_DEFAULT,
    LLM_GRAPH_TYPE_ENCODER,
    LLM_GRAPH_TYPE_DECODER,
};

enum llm_ffn_op_type {
    LLM_FFN_SILU,
    LLM_FFN_GELU,
    LLM_FFN_RELU,
    LLM_FFN_RELU_SQR,
    LLM_FFN_SWIGLU,
    LLM_FFN_GEGLU,
    LLM_FFN_REGLU,
    LLM_FFN_SWIGLU_OAI_MOE,
};

enum llm_ffn_gate_type {
    LLM_FFN_SEQ,
    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
};

enum llm_norm_type {
    LLM_NORM,
    LLM_NORM_RMS,
    LLM_NORM_GROUP,
};

// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
    // the output embeddings from the encoder as a ggml tensor
    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
    //ggml_tensor * t_embd = nullptr;

    int64_t n_embd = 0;
    int64_t n_enc  = 0;

    // embeddings data copied to host memory (tmp)
    std::vector<float> v_embd;

    // needed to construct the cross-attention mask in the decoder
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
};

struct llm_graph_params;

//
// llm_graph_input
//

class llm_graph_input_i {
public:
    llm_graph_input_i() {
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
    }

    virtual ~llm_graph_input_i() = default;

    virtual void set_input(const llama_ubatch * ubatch) = 0;

    // return true if the input tensors resulting from the provided graph parameters would be
    //   the same as the input tensors currently stored in this object
    virtual bool can_reuse(const llm_graph_params & params) {
        // returning false by default prevents reusing the graph when the check
        //   for this input type has not been implemented yet
        GGML_UNUSED(params);
        return false;
    }
protected:
    // env: LLAMA_GRAPH_INPUT_DEBUG
    int debug = 0;
};
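
// illustrative sketch (an assumption, not the actual implementation): a typical can_reuse()
//   override checks that the tensors already created for the previous graph are compatible with
//   the new parameters, e.g. for a token/embedding input:
//
//     bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
//         bool res = true;
//         res &= (!tokens || tokens->ne[0] == params.ubatch.n_tokens);
//         res &= (!embd   || embd->ne[1]   == params.ubatch.n_tokens);
//         return res;
//     }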

using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;

class llm_graph_input_embd : public llm_graph_input_i {
public:
    llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
    virtual ~llm_graph_input_embd() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * tokens = nullptr; // I32 [n_batch]
    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]

    const int64_t n_embd = 0;
};

class llm_graph_input_pos : public llm_graph_input_i {
public:
    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
    virtual ~llm_graph_input_pos() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * pos = nullptr; // I32 [n_batch]

    const uint32_t n_pos_per_embd = 1;
};

// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
    virtual ~llm_graph_input_attn_temp() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]

    const uint32_t n_attn_temp_floor_scale;
    const float    f_attn_temp_scale;
    const float    f_attn_temp_offset;
};

class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
    virtual ~llm_graph_input_pos_bucket() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]

    const llama_hparams hparams;
};

class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
    llm_graph_input_pos_bucket_kv(
            const llama_hparams & hparams,
            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
    virtual ~llm_graph_input_pos_bucket_kv() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

    const llama_hparams hparams;

    const llama_kv_cache_context * mctx;
};

class llm_graph_input_out_ids : public llm_graph_input_i {
public:
    llm_graph_input_out_ids(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
    virtual ~llm_graph_input_out_ids() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * out_ids; // I32 [n_outputs]

    const llama_hparams hparams;
    const llama_cparams cparams;

    const uint32_t n_outputs;
};

class llm_graph_input_mean : public llm_graph_input_i {
public:
    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
    virtual ~llm_graph_input_mean() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * mean; // F32 [n_batch, n_batch]

    const llama_cparams cparams;
};

class llm_graph_input_cls : public llm_graph_input_i {
public:
    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
    virtual ~llm_graph_input_cls() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * cls; // I32 [n_batch]

    const llama_cparams cparams;
    const llm_arch arch;
};

class llm_graph_input_rs : public llm_graph_input_i {
public:
    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
    virtual ~llm_graph_input_rs() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * s_copy;  // I32 [n_rs]

    // views of s_copy, computed once per graph
    // and shared across layers which use build_rs
    ggml_tensor * s_copy_main;   // I32 [n_seqs]
    ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]

    const llama_memory_recurrent_context * mctx;

    // used in view offsets, need to match for valid graph reuse
    uint32_t head;
    int32_t rs_z;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
    llm_graph_input_cross_embd(
            const llama_cross * cross) : cross(cross) {}
    virtual ~llm_graph_input_cross_embd() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]

    const llama_cross * cross;
};

class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
        hparams(hparams),
        cparams(cparams) {
    }
    ~llm_graph_input_attn_no_cache() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

    // n_tokens == n_batch
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;
};

class llm_graph_input_attn_kv : public llm_graph_input_i {
public:
    llm_graph_input_attn_kv(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_kv() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
    ggml_tensor * get_v_idxs() const { return self_v_idxs; }

    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

    // note: these have to be copies because in order to be able to reuse a graph, its inputs
    //       need to carry these parameters with them. otherwise, they can point to freed
    //       llm_graph_params from a previous batch, causing stack-use-after-return
    const llama_hparams hparams;
    const llama_cparams cparams;

    const llama_kv_cache_context * mctx;
};

// V-less input for the KV cache
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
class llm_graph_input_attn_k : public llm_graph_input_i {
public:
    llm_graph_input_attn_k(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_k() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs() const { return self_k_idxs; }

    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]

    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;

    const llama_kv_cache_context * mctx;
};

class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
    llm_graph_input_attn_kv_iswa(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_iswa_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_kv_iswa() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;

    const llama_kv_cache_iswa_context * mctx;
};

class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
    ~llm_graph_input_attn_cross() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

    const llama_cross * cross = nullptr;
};

class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
            const llama_memory_hybrid_context *      mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
    std::unique_ptr<llm_graph_input_rs>      inp_rs;

    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }

    const llama_cparams cparams;

    const llama_memory_hybrid_context * mctx;
};

class llm_graph_input_mem_hybrid_k : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid_k(
            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_k> inp_attn,
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
            const llama_memory_hybrid_context *      mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid_k() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    std::unique_ptr<llm_graph_input_attn_k> inp_attn;
    std::unique_ptr<llm_graph_input_rs>      inp_rs;

    llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }

    const llama_cparams cparams;

    const llama_memory_hybrid_context * mctx;
};

class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid_iswa(
            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
            std::unique_ptr<llm_graph_input_rs>          inp_rs,
            const llama_memory_hybrid_iswa_context *     mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid_iswa() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
    std::unique_ptr<llm_graph_input_rs>          inp_rs;

    llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs           * get_recr() const { return inp_rs.get(); }

    const llama_cparams cparams;

    const llama_memory_hybrid_iswa_context * mctx;
};

class llm_graph_input_sampling : public llm_graph_input_i {
public:
    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
        samplers(std::move(samplers)) { }
    virtual ~llm_graph_input_sampling() = default;

    void set_input(const llama_ubatch * ubatch) override;
    bool can_reuse(const llm_graph_params & params) override;

    std::map<llama_seq_id, llama_sampler *> samplers;
};

//
// llm_graph_result
//

// these objects deliver the result of the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
//   specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
//   these are used by the llama_context to extract the relevant data, based on the compute parameters

// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
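
// example (illustrative sketch, not part of the API): a callback that names each tensor per
//   layer, which is one way such a callback can be used (e.g. for debugging or offloading):
//
//     llm_graph_cb cb = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
//         GGML_UNUSED(ubatch);
//         if (il >= 0) {
//             ggml_format_name(cur, "%s-%d", name, il);
//         } else {
//             ggml_set_name(cur, name);
//         }
//     };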

class llm_graph_result;

struct llm_graph_params {
    llm_arch arch = LLM_ARCH_UNKNOWN;

    llama_hparams hparams;
    llama_cparams cparams;

    llama_ubatch ubatch; // note: intentionally make a copy

    llm_graph_type gtype;

    ggml_backend_sched_t sched;
    ggml_backend_t backend_cpu;

    const llama_adapter_cvec     * cvec;
    const llama_adapter_loras    * loras;
    const llama_memory_context_i * mctx;
    const llama_cross            * cross;

    std::map<llama_seq_id, llama_sampler *> samplers;

    static bool samplers_equal(
          const std::map<llama_seq_id, llama_sampler *> & lhs,
          const std::map<llama_seq_id, llama_sampler *> & rhs) {
        if (lhs.size() != rhs.size()) {
            return false;
        }
        for (const auto & [seq_id, sampler] : lhs) {
            auto it = rhs.find(seq_id);
            if (it == rhs.end() || it->second != sampler) {
                return false;
            }
        }
        return true;
    }

    uint32_t n_outputs;

    llm_graph_cb cb;

    llm_graph_result * res;

    // return true if the "other" params would result in a graph with the same topology as the current params
    //   having the same topology allows us to reuse the graph in some cases
    bool allow_reuse(const llm_graph_params & other) const {
        // first check the ubatch
        bool can_reuse_ubatch =
            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
            ubatch.n_tokens     == other.ubatch.n_tokens &&
            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
            ubatch.n_seqs       == other.ubatch.n_seqs &&
            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
            (
                (!ubatch.token && !other.ubatch.token) ||
                (!ubatch.embd  && !other.ubatch.embd)
            );

        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same,
        //   because the set of attention streams would be different for different sequences
        if (can_reuse_ubatch && ubatch.equal_seqs()) {
            if (!ubatch.data) {
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
                //   therefore we cannot perform the sequence id check. this should normally never happen
                can_reuse_ubatch = false;
            } else {
                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
                }
            }
        }

        if (!can_reuse_ubatch) {
            return false;
        }

        if (n_outputs != other.n_outputs) {
            return false;
        }

        if (!samplers_equal(samplers, other.samplers)) {
            return false;
        }

        if (samplers.size() > 0) {
            if (!ubatch.data || !other.ubatch.data) {
                return false;
            }

            // check that the outputs are the same for all samplers
            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                if (ubatch.output[i]    != other.ubatch.output[i] ||
                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
                    return false;
                }
            }
        }

        return
            cparams.embeddings  == other.cparams.embeddings  &&
            cparams.causal_attn == other.cparams.causal_attn &&
            arch  == other.arch  &&
            gtype == other.gtype &&
            cvec  == other.cvec  &&
            loras == other.loras &&
            cross == other.cross;
    }
};

class llm_graph_result {
public:
    llm_graph_result(int64_t max_nodes);

    virtual ~llm_graph_result() = default;

    ggml_tensor * get_inp_tokens()  const { return t_inp_tokens; }
    ggml_tensor * get_logits()      const { return t_logits; }
    ggml_tensor * get_embd()        const { return t_embd; }
    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }

    ggml_cgraph  * get_gf()  const { return gf; }
    ggml_context * get_ctx() const { return ctx_compute.get(); }

    int64_t get_max_nodes() const;

    void reset();

    void set_inputs(const llama_ubatch * ubatch);
    void set_outputs();

    // try to update the existing graph result using the new graph parameters in order to reuse it
    // this can only be done if we determine that the resulting graph using the new graph parameters
    //   would be identical to the existing graph. in that case, we simply have to update the memory
    //   contexts of the input tensors of the graph and we can reuse it for another computation
    // return true if the graph was updated and can be reused
    bool can_reuse(const llm_graph_params & params);
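
    // illustrative sketch (an assumption about the calling code, simplified):
    //
    //     if (res_prev && res_prev->can_reuse(params)) {
    //         // reuse the previous graph - only the inputs need to be re-populated
    //     } else {
    //         // build a new graph from scratch using the new params
    //     }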

    llm_graph_input_i * add_input(llm_graph_input_ptr input);

    void set_params(const llm_graph_params & params);

    // important graph nodes
    ggml_tensor * t_inp_tokens  = nullptr;
    ggml_tensor * t_inp_embd    = nullptr; // [n_embd_inp, n_tokens]
    ggml_tensor * t_logits      = nullptr;
    ggml_tensor * t_embd        = nullptr;
    ggml_tensor * t_embd_pooled = nullptr;

    std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
    std::map<llama_seq_id, ggml_tensor*> t_candidates;
    std::map<llama_seq_id, ggml_tensor*> t_sampled;
    std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;

    std::vector<llm_graph_input_ptr> inputs;

    ggml_context_ptr ctx_compute;

    // memory buffers used to evaluate the model
    std::vector<uint8_t> buf_compute_meta;

    ggml_cgraph * gf;

    int64_t max_nodes;

private:
    // keep a copy of the previous graph parameters
    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
    // note: these are updated after constructing the new graph
    llm_graph_params params;

    // env: LLAMA_GRAPH_RESULT_DEBUG
    int debug = 0;
};

using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;

//
// llm_graph_context
//

// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
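
// illustrative example: the default is plain ggml_get_rows; a caller can pass a custom function,
//   e.g. (hypothetical) one that post-processes the gathered states before they are written back:
//
//     llm_graph_get_rows_fn get_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
//         ggml_tensor * rows = ggml_get_rows(ctx, states, ids);
//         return rows; // a custom implementation could insert additional ops here
//     };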

struct llm_graph_context {
    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch  & ubatch;

    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int64_t n_tokens;
    const int64_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type    rope_type;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    const llama_adapter_cvec     * cvec;
    const llama_adapter_loras    * loras;
    const llama_memory_context_i * mctx;
    const llama_cross            * cross;

    std::map<llama_seq_id, llama_sampler *> samplers;

    const llm_graph_cb & cb_func;

    llm_graph_result * res;

    ggml_context * ctx0 = nullptr;
    ggml_cgraph  * gf   = nullptr;

    llm_graph_context(const llm_graph_params & params);
    virtual ~llm_graph_context() = default;

    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    ggml_tensor * build_cvec(
             ggml_tensor * cur,
                     int   il) const;

    // do mat_mul, optionally applying lora
    ggml_tensor * build_lora_mm(
              ggml_tensor * w,
              ggml_tensor * cur) const;
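
    // illustrative usage (the model tensors here are an assumption for the example):
    //
    //     cur = build_lora_mm(model.layers[il].wq, cur);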

    // do mat_mul_id, optionally applying lora
    ggml_tensor * build_lora_mm_id(
              ggml_tensor * w,   // ggml_tensor * as
              ggml_tensor * cur, // ggml_tensor * b
              ggml_tensor * ids) const;

    ggml_tensor * build_norm(
             ggml_tensor * cur,
             ggml_tensor * mw,
             ggml_tensor * mb,
           llm_norm_type   type,
                     int   il) const;

    ggml_tensor * build_ffn(
             ggml_tensor * cur,
             ggml_tensor * up,
             ggml_tensor * up_b,
             ggml_tensor * up_s,
             ggml_tensor * gate,
             ggml_tensor * gate_b,
             ggml_tensor * gate_s,
             ggml_tensor * down,
             ggml_tensor * down_b,
             ggml_tensor * down_s,
             ggml_tensor * act_scales,
         llm_ffn_op_type   type_op,
       llm_ffn_gate_type   type_gate,
                     int   il) const;

    // build MoE FFN without bias tensors
    ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
             ggml_tensor * up_exps,
             ggml_tensor * gate_exps,
             ggml_tensor * down_exps,
             ggml_tensor * exp_probs_b,
                 int64_t   n_expert,
                 int64_t   n_expert_used,
         llm_ffn_op_type   type_op,
                    bool   norm_w,
                    bool   scale_w,
                   float   w_scale,
            llama_expert_gating_func_type gating_op,
                     int   il,
             ggml_tensor * probs_in = nullptr) const;

    ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
             ggml_tensor * gate_inp_b,
             ggml_tensor * up_exps,
             ggml_tensor * up_exps_b,
             ggml_tensor * gate_exps,
             ggml_tensor * gate_exps_b,
             ggml_tensor * down_exps,
             ggml_tensor * down_exps_b,
             ggml_tensor * exp_probs_b,
                 int64_t   n_expert,
                 int64_t   n_expert_used,
         llm_ffn_op_type   type_op,
                    bool   norm_w,
                    bool   scale_w,
                   float   w_scale,
            llama_expert_gating_func_type gating_op,
                     int   il,
             ggml_tensor * probs_in = nullptr) const;

    //
    // inputs
    //

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;

    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

    ggml_tensor * build_attn_mha(
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            ggml_tensor * sinks,   // [n_head_q]
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_no_cache * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_kv * build_attn_inp_kv() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_kv * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
                  float   kq_scale,
                    int   il) const;
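
    // illustrative usage inside a self-attention block (the model tensors, Qcur/Kcur/Vcur and
    //   the scale are assumptions for the example):
    //
    //     auto * inp_attn = build_attn_inp_kv();
    //     ...
    //     cur = build_attn(inp_attn,
    //             model.layers[il].wo, model.layers[il].bo,
    //             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);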

    llm_graph_input_attn_k  * build_attn_inp_k() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_k * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_iswa * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    llm_graph_input_attn_cross * build_attn_inp_cross() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_cross * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    //
    // recurrent
    //

    // TODO: move this implementation to llama_memory_recurrent.
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //         `llama_memory_recurrent`
    ggml_tensor * build_rs(
            ggml_tensor * s,
            ggml_tensor * state_copy_main,
            ggml_tensor * state_copy_extra,
                int32_t   state_size,
                int32_t   n_seqs,
               uint32_t   n_rs,
               uint32_t   rs_head,
               uint32_t   rs_size,
                int32_t   rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

    ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
            ggml_tensor * s,
                int32_t   state_size,
                int32_t   n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    ggml_tensor * build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
        const llama_ubatch & ubatch,
                       int   il) const;

    ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
      const llama_ubatch & ubatch,
                     int   il) const;

    //
    // hybrid
    //

    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
    llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const;

    llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;

    //
    // pooling
    //

    void build_pooling(
            ggml_tensor * cls,
            ggml_tensor * cls_b,
            ggml_tensor * cls_out,
            ggml_tensor * cls_out_b) const;

    //
    // sampling (backend sampling)
    //

    void build_sampling() const;

    //
    // dense (out)
    //

    void build_dense_out(
            ggml_tensor * dense_2,
            ggml_tensor * dense_3) const;
};

// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
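
// illustrative call (the argument values are made up for the example):
//
//     const int32_t bucket = llama_relative_position_bucket(/*x=*/ 3, /*y=*/ 10, /*n_buckets=*/ 32, /*bidirectional=*/ true);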