#include "llama-graph.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"

#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"

#include <cassert>
#include <cmath>
#include <cstring>
#include <numeric>
#include <sstream>
#include <unordered_set>

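// copy the token ids and/or the raw embeddings of the current ubatch into the graph inputs
// only the member that is present in the ubatch is uploaded - the other input stays untouched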
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
    if (ubatch->token) {
        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
    }

    if (ubatch->embd) {
        GGML_ASSERT(n_embd == embd->ne[0]);

        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
    }
}

bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
    res &= (!params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens);

    return res;
}

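// upload the token positions
// for M-RoPE (n_pos_per_embd == 4) with text-only tokens, the 1D positions are expanded into
// 4 consecutive sections:
//
//   [p_0 .. p_n | p_0 .. p_n | p_0 .. p_n | 0 .. 0]
//
// i.e. the first 3 sections repeat the text position and the 4th section is all 0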
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && pos) {
        const int64_t n_tokens = ubatch->n_tokens;

        if (ubatch->token && n_pos_per_embd == 4) {
            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
            // the first 3 dims are the same, and the 4th dim is all 0
            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
            // copy the first dimension
            for (int i = 0; i < n_tokens; ++i) {
                pos_data[               i] = ubatch->pos[i];
                pos_data[    n_tokens + i] = ubatch->pos[i];
                pos_data[2 * n_tokens + i] = ubatch->pos[i];
                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
            }
            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
        } else {
            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
        }
    }
}

bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;

    return res;
}

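// compute the per-token attention temperature scale (used for example by Llama 4):
//
//   scale(pos) = 1.0 + f_attn_temp_scale * log(floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0)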
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(f_attn_temp_scale != 0.0f);
        GGML_ASSERT(n_attn_temp_floor_scale != 0);

        std::vector<float> attn_scale_data(n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const float pos = ubatch->pos[i];
            attn_scale_data[i] = std::log(
                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
            ) * f_attn_temp_scale + 1.0;
        }

        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
    }
}

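// fill the matrix of T5-style relative position buckets for all pairs of tokens in the ubatch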
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

        int32_t * data = (int32_t *) pos_bucket->data;

        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_tokens; ++i) {
                data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
            }
        }
    }
}

void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        mctx->set_input_pos_bucket(pos_bucket, ubatch);
    }
}

void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(out_ids);

    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
    int32_t * data = (int32_t *) out_ids->data;

    if (n_outputs == n_tokens) {
        for (int i = 0; i < n_tokens; ++i) {
            data[i] = i;
        }

        return;
    }

    GGML_ASSERT(ubatch->output);

    int n_outputs = 0;

    for (int i = 0; i < n_tokens; ++i) {
        if (ubatch->output[i]) {
            data[n_outputs++] = i;
        }
    }
}

bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= n_outputs == params.n_outputs;

    return res;
}

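// build the [n_tokens, n_seqs_unq] mean-pooling matrix:
// row s holds 1/len(seq_s) at the columns of the tokens that belong to sequence s and 0 elsewhere,
// so a single matrix multiplication yields the per-sequence mean embedding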
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
        const int64_t n_tokens     = ubatch->n_tokens;
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

        GGML_ASSERT(mean);
        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

        float * data = (float *) mean->data;
        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

        std::vector<uint64_t> sums(n_seqs_unq, 0);
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                sums[seq_idx] += ubatch->n_seq_tokens;
            }
        }

        std::vector<float> div(n_seqs_unq, 0.0f);
        for (int s = 0; s < n_seqs_unq; ++s) {
            const uint64_t sum = sums[s];
            if (sum > 0) {
                div[s] = 1.0f/float(sum);
            }
        }

        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                for (int j = 0; j < n_seq_tokens; ++j) {
                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
                }
            }
        }
    }
}

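// select one representative token per sequence for CLS/RANK/LAST pooling:
// the token with the smallest position, or the largest position for last-token pooling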
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
    const int64_t n_tokens     = ubatch->n_tokens;
    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

    if (cparams.embeddings && (
        cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  ||
        cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
        cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
    )) {
        GGML_ASSERT(cls);
        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

        uint32_t * data = (uint32_t *) cls->data;
        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

        std::vector<int> target_pos(n_seqs_unq, -1);
        std::vector<int> target_row(n_seqs_unq, -1);

        const bool last = (
             cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
        );

        for (int i = 0; i < n_tokens; ++i) {
            const llama_pos pos = ubatch->pos[i];

            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                if (
                    (target_pos[seq_idx] == -1) ||
                    ( last && pos >= target_pos[seq_idx]) ||
                    (!last && pos <  target_pos[seq_idx])
                ) {
                    target_pos[seq_idx] = pos;
                    target_row[seq_idx] = i;
                }
            }
        }

        for (int s = 0; s < n_seqs_unq; ++s) {
            if (target_row[s] >= 0) {
                data[s] = target_row[s];
            }
        }
    }
}

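// upload the copy-source indices of the recurrent-state cells used by this ubatch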
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    const int64_t n_rs = mctx->get_n_rs();

    if (s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
        int32_t * data = (int32_t *) s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->s_copy(i);
        }
    }
}

bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= s_copy->ne[0] == mctx->get_n_rs();

    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;

    res &= head == mctx->get_head();
    res &= rs_z == mctx->get_rs_z();

    return res;
}

void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (cross_embd && !cross->v_embd.empty()) {
        assert(cross_embd->type == GGML_TYPE_F32);

        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
    }
}

static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
    const char * swa_type_str = "unknown";

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
    }

    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);

    LLAMA_LOG_DEBUG("    ");
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
        LLAMA_LOG_DEBUG("%2d", j);
    }
    LLAMA_LOG_DEBUG("\n");

    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
        LLAMA_LOG_DEBUG(" %2d ", i);
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
            float val = data[i * n_kv + j];
            if (val == -INFINITY) {
                LLAMA_LOG_DEBUG(" ∞");
            } else {
                LLAMA_LOG_DEBUG(" 0");
            }
        }
        LLAMA_LOG_DEBUG("\n");
    }
}

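// build the attention masks directly from the ubatch (no KV cache is involved, so n_kv == n_tokens)
// a pair of positions can attend if they belong to the same sequence, respect causality (when
// causal attention is enabled) and pass the sliding-window check - everything else stays at -INFINITY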
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
    const int64_t n_kv     = ubatch->n_tokens;
    const int64_t n_tokens = ubatch->n_tokens;

    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
        for (int i1 = 0; i1 < n_tokens; ++i1) {
            const llama_seq_id s1 = ubatch->seq_id[i1][0];
            const llama_pos    p1 = ubatch->pos[i1];

            const uint64_t idst = i1*n_kv;

            for (int i0 = 0; i0 < n_tokens; ++i0) {
                const llama_seq_id s0 = ubatch->seq_id[i0][0];
                const llama_pos p0    = ubatch->pos[i0];

                // mask different sequences
                if (s0 != s1) {
                    continue;
                }

                // mask future tokens
                if (cparams.causal_attn && p0 > p1) {
                    continue;
                }

                // apply SWA if any
                if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
                    continue;
                }

                data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
            }
        }
    };

    {
        GGML_ASSERT(self_kq_mask);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));

        float * data = (float *) self_kq_mask->data;

        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);

        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);

        if (debug) {
            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
        }
    }

    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
        GGML_ASSERT(self_kq_mask_swa);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));

        float * data = (float *) self_kq_mask_swa->data;

        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);

        fill_mask(data, hparams.n_swa, hparams.swa_type);

        if (debug) {
            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
        }
    }
}

void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

    return res;
}

void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_k_idxs(self_k_idxs, ubatch);

    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;

    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

    return res;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);

    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
    res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;

    return res;
}

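// build the cross-attention mask: query token i can attend to encoder output j
// only if they share at least one sequence id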
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(cross_kq_mask);

    const int64_t n_enc    = cross_kq_mask->ne[0];
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    float * data = (float *) cross_kq_mask->data;

    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_enc; ++j) {
            float f = -INFINITY;

            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = ubatch->seq_id[i][s];

                if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
                    f = 0.0f;
                }
            }

            data[i*n_enc + j] = f;
        }
    }
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);

    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

    const int64_t n_rs = mctx->get_recr()->get_n_rs();

    if (inp_rs->s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
        int32_t * data = (int32_t *) inp_rs->s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->get_recr()->s_copy(i);
        }
    }
}

bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;

    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();

    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;

    res &= inp_rs->head == mctx->get_recr()->get_head();
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();

    return res;
}

// TODO: Hybrid input classes are a bit redundant.
// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
// Refactoring is required in the future.
void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);

    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

    const int64_t n_rs = mctx->get_recr()->get_n_rs();

    if (inp_rs->s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
        int32_t * data = (int32_t *) inp_rs->s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->get_recr()->s_copy(i);
        }
    }
}

bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;

    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;

    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();

    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;

    res &= inp_rs->head == mctx->get_recr()->get_head();
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();

    return res;
}

void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
    const auto * attn_ctx = mctx->get_attn();

    // base tensors may not be allocated if there are no non-SWA attention layers
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
        attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
        attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);

        attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
    }

    // swa tensors may not be allocated if there are no SWA attention layers
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
        attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
        attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);

        attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
    }

    const int64_t n_rs = mctx->get_recr()->get_n_rs();

    if (inp_rs->s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
        int32_t * data = (int32_t *) inp_rs->s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->get_recr()->s_copy(i);
        }
    }
}

bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    const auto * attn_ctx = mctx->get_attn();

    // base tensors may not be allocated if there are no non-SWA attention layers
    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
        res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
      //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

        res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
        res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
    }

    // swa tensors may not be allocated if there are no SWA attention layers
    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
        res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
      //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

        res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
        res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
    }

    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();

    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;

    res &= inp_rs->head == mctx->get_recr()->get_head();
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();

    return res;
}

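// forward the input update only to the backend samplers of the sequences
// that produce outputs in the current ubatch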
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
    // set the inputs only for the active samplers in the current ubatch
    std::unordered_set<llama_seq_id> active_samplers;
    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
        if (ubatch->output[i]) {
            llama_seq_id seq_id = ubatch->seq_id[i][0];
            active_samplers.insert(seq_id);
        }
    }

    for (auto seq_id : active_samplers) {
        if (samplers.find(seq_id) == samplers.end()) {
            continue;
        }

        auto & sampler = samplers[seq_id];

        if (sampler->iface->backend_set_input) {
            sampler->iface->backend_set_input(sampler);
        }
    }
}

bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
    if (samplers.size() != params.samplers.size()) {
        return false;
    }

    for (const auto & [seq_id, sampler] : params.samplers) {
        if (samplers[seq_id] != sampler) {
            return false;
        }
    }

    return true;
}

//
// llm_graph_result
//

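// note: the result owns the compute context and the graph, and is meant to be reused across
//       decode calls - can_reuse() below verifies that the new parameters and all graph inputs
//       are shape-compatible with the previously built graph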
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
    reset();

    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
}

int64_t llm_graph_result::get_max_nodes() const {
    return max_nodes;
}

void llm_graph_result::reset() {
    t_inp_tokens  = nullptr;
    t_inp_embd    = nullptr;
    t_logits      = nullptr;
    t_embd        = nullptr;
    t_embd_pooled = nullptr;
    t_sampled.clear();
    t_sampled_probs.clear();
    t_sampled_logits.clear();
    t_candidates.clear();

    params = {};

    inputs.clear();

    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

    ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(),
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    ctx_compute.reset(ggml_init(params));

    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}

void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
    for (auto & input : inputs) {
        input->set_input(ubatch);
    }
}

void llm_graph_result::set_outputs() {
    if (t_logits != nullptr) {
        ggml_set_output(t_logits);
    }
    if (t_embd != nullptr) {
        ggml_set_output(t_embd);
    }
    if (t_embd_pooled != nullptr) {
        ggml_set_output(t_embd_pooled);
    }
    for (auto & [seq_id, t] : t_sampled) {
        if (t != nullptr) {
            ggml_set_output(t);
        }
    }
    for (auto & [seq_id, t] : t_sampled_probs) {
        if (t != nullptr) {
            ggml_set_output(t);
        }
    }
    for (auto & [seq_id, t] : t_sampled_logits) {
        if (t != nullptr) {
            ggml_set_output(t);
        }
    }
    for (auto & [seq_id, t] : t_candidates) {
        if (t != nullptr) {
            ggml_set_output(t);
        }
    }
}

bool llm_graph_result::can_reuse(const llm_graph_params & params) {
    if (!this->params.allow_reuse(params)) {
        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
        }

        return false;
    }

    if (debug > 1) {
        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
    }

    bool res = true;

    for (auto & input : inputs) {
        const bool cur = input->can_reuse(params);

        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", __func__, cur);
        }

        res = res && cur;
    }

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
    }

    return res;
}

llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
    inputs.emplace_back(std::move(input));
    return inputs.back().get();
}

void llm_graph_result::set_params(const llm_graph_params & params) {
    this->params = params;
}

//
// llm_graph_context
//

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch             (params.arch),
    hparams          (params.hparams),
    cparams          (params.cparams),
    ubatch           (params.ubatch),
    n_embd           (hparams.n_embd),
    n_layer          (hparams.n_layer),
    n_rot            (hparams.n_rot),
    n_ctx            (cparams.n_ctx),
    n_head           (hparams.n_head()),
    n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
    n_embd_head_v    (hparams.n_embd_head_v),
    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
    n_expert         (hparams.n_expert),
    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base        (cparams.rope_freq_base),
    freq_scale       (cparams.rope_freq_scale),
    ext_factor       (cparams.yarn_ext_factor),
    attn_factor      (cparams.yarn_attn_factor),
    beta_fast        (cparams.yarn_beta_fast),
    beta_slow        (cparams.yarn_beta_slow),
    norm_eps         (hparams.f_norm_eps),
    norm_rms_eps     (hparams.f_norm_rms_eps),
    n_tokens         (ubatch.n_tokens),
    n_outputs        (params.n_outputs),
    n_ctx_orig       (cparams.n_ctx_orig_yarn),
    pooling_type     (cparams.pooling_type),
    rope_type        (hparams.rope_type),
    sched            (params.sched),
    backend_cpu      (params.backend_cpu),
    cvec             (params.cvec),
    loras            (params.loras),
    mctx             (params.mctx),
    cross            (params.cross),
    samplers         (params.samplers),
    cb_func          (params.cb),
    res              (params.res),
    ctx0             (res->get_ctx()),
    gf               (res->get_gf()) {
        res->set_params(params);
    }

void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
    if (cb_func) {
        cb_func(ubatch, cur, name, il);
    }
}

ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
                 int   il) const {
    return cvec->apply_to(ctx0, cur, il);
}

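// matrix multiplication with the deltas of all active LoRA adapters applied on top:
//
//   res = w @ cur + sum_i scale_i * (b_i @ (a_i @ cur))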
ggml_tensor * llm_graph_context::build_lora_mm(
          ggml_tensor * w,
          ggml_tensor * cur) const {
    ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);

    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float adapter_scale = lora.second;
        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

        ggml_tensor * ab_cur = ggml_mul_mat(
                ctx0, lw->b,
                ggml_mul_mat(ctx0, lw->a, cur)
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_lora_mm_id(
          ggml_tensor * w,   // ggml_tensor * as
          ggml_tensor * cur, // ggml_tensor * b
          ggml_tensor * ids) const {
    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float alpha = lora.first->alpha;
        const float rank  = (float) lw->b->ne[0];
        const float scale = alpha ? lora.second * alpha / rank : lora.second;

        ggml_tensor * ab_cur = ggml_mul_mat_id(
                ctx0, lw->b,
                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
                ids
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_norm(
         ggml_tensor * cur,
         ggml_tensor * mw,
         ggml_tensor * mb,
       llm_norm_type   type,
                 int   il) const {
    switch (type) {
        case LLM_NORM:       cur = ggml_norm    (ctx0, cur, hparams.f_norm_eps);     break;
        case LLM_NORM_RMS:   cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
        case LLM_NORM_GROUP:
            {
                cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0],    cur->ne[2]);
            } break;
    }

    if (mw || mb) {
        cb(cur, "norm", il);
    }

    if (mw) {
        cur = ggml_mul(ctx0, cur, mw);
        if (mb) {
            cb(cur, "norm_w", il);
        }
    }

    if (mb) {
        cur = ggml_add(ctx0, cur, mb);
    }

    return cur;
}

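// generic feed-forward block: up (+ bias/scale) -> optional gate (+ bias/scale) -> activation -> down (+ bias/scale)
// with LLM_FFN_PAR the gate is applied to the FFN input in parallel with the up projection,
// while with LLM_FFN_SEQ it is applied sequentially to the output of the up projection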
ggml_tensor * llm_graph_context::build_ffn(
         ggml_tensor * cur,
         ggml_tensor * up,
         ggml_tensor * up_b,
         ggml_tensor * up_s,
         ggml_tensor * gate,
         ggml_tensor * gate_b,
         ggml_tensor * gate_s,
         ggml_tensor * down,
         ggml_tensor * down_b,
         ggml_tensor * down_s,
         ggml_tensor * act_scales,
     llm_ffn_op_type   type_op,
   llm_ffn_gate_type   type_gate,
                 int   il) const {
    ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
    cb(tmp, "ffn_up", il);

    if (up_b) {
        tmp = ggml_add(ctx0, tmp, up_b);
        cb(tmp, "ffn_up_b", il);
    }

    if (up_s) {
        tmp = ggml_mul(ctx0, tmp, up_s);
        cb(tmp, "ffn_up_s", il);
    }

    if (gate) {
        switch (type_gate) {
            case LLM_FFN_SEQ:
                {
                    cur = build_lora_mm(gate, tmp);
                    cb(cur, "ffn_gate", il);
                } break;
            case LLM_FFN_PAR:
                {
                    cur = build_lora_mm(gate, cur);
                    cb(cur, "ffn_gate", il);
                } break;
        }

        if (gate_b) {
            cur = ggml_add(ctx0, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }

        if (gate_s) {
            cur = ggml_mul(ctx0, cur, gate_s);
            cb(cur, "ffn_gate_s", il);
        }

    } else {
        cur = tmp;
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate && type_gate == LLM_FFN_PAR) {
                // Step35: HF clamps gate (after SiLU) and up before multiplication
                if (arch == LLM_ARCH_STEP35 && il >= 0) {
                    const float limit = hparams.swiglu_clamp_shexp[il];
                    constexpr float eps = 1e-6f;
                    if (limit > eps) {
                        ggml_tensor * gate_act = ggml_silu(ctx0, cur);
                        cb(gate_act, "ffn_silu", il);
                        gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
                        cb(gate_act, "ffn_silu_clamped", il);

                        tmp = ggml_clamp(ctx0, tmp, -limit, limit);
                        cb(tmp, "ffn_up_clamped", il);

                        cur = ggml_mul(ctx0, gate_act, tmp);
                        cb(cur, "ffn_swiglu_limited", il);
                        type_gate = LLM_FFN_SEQ;
                        break;
                    }
                }

                cur = ggml_swiglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_swiglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_geglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_geglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                if (act_scales != NULL) {
                    cur = ggml_div(ctx0, cur, act_scales);
                    cb(cur, "ffn_act", il);
                }
            } break;
        case LLM_FFN_RELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_reglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_reglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            } break;
        case LLM_FFN_RELU_SQR:
            {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);

                cur = ggml_sqr(ctx0, cur);
                cb(cur, "ffn_sqr(relu)", il);
            } break;
        case LLM_FFN_SWIGLU:
            {
                cur = ggml_swiglu(ctx0, cur);
                cb(cur, "ffn_swiglu", il);
            } break;
        case LLM_FFN_GEGLU:
            {
                cur = ggml_geglu(ctx0, cur);
                cb(cur, "ffn_geglu", il);
            } break;
        case LLM_FFN_REGLU:
            {
                cur = ggml_reglu(ctx0, cur);
                cb(cur, "ffn_reglu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    if (gate && type_gate == LLM_FFN_PAR) {
        cur = ggml_mul(ctx0, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }

    if (down) {
        cur = build_lora_mm(down, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (down_b) {
        cb(cur, "ffn_down", il);

        cur = ggml_add(ctx0, cur, down_b);
    }

    if (down_s) {
        cur = ggml_mul(ctx0, cur, down_s);
        cb(cur, "ffn_down_s", il);
    }

    return cur;
}

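// mixture-of-experts FFN:
//   1. compute the routing logits and the gating probabilities
//   2. optionally bias and/or group-restrict the selection probabilities and pick the top n_expert_used experts
//   3. run the expert FFNs via the indirect matrix multiplications (ggml_mul_mat_id)
//   4. weight the expert outputs and sum them into the final result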
ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * up_exps,
         ggml_tensor * gate_exps,
         ggml_tensor * down_exps,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
                bool   scale_w,
               float   w_scale,
         llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in) const {
    return build_moe_ffn(
        cur,
        gate_inp,  /* gate_inp_b  */ nullptr,
        up_exps,   /* up_exps_b   */ nullptr,
        gate_exps, /* gate_exps_b */ nullptr,
        down_exps, /* down_exps_b */ nullptr,
        exp_probs_b,
        n_expert,
        n_expert_used,
        type_op,
        norm_w,
        scale_w,
        w_scale,
        gating_op,
        il,
        probs_in
    );
}

ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_tensor * cur,
         ggml_tensor * gate_inp,
         ggml_tensor * gate_inp_b,
         ggml_tensor * up_exps,
         ggml_tensor * up_exps_b,
         ggml_tensor * gate_exps,
         ggml_tensor * gate_exps_b,
         ggml_tensor * down_exps,
         ggml_tensor * down_exps_b,
         ggml_tensor * exp_probs_b,
             int64_t   n_expert,
             int64_t   n_expert_used,
     llm_ffn_op_type   type_op,
                bool   norm_w,
                bool   scale_w,
               float   w_scale,
        llama_expert_gating_func_type gating_op,
                 int   il,
         ggml_tensor * probs_in) const {
    const int64_t n_embd   = cur->ne[0];
    const int64_t n_tokens = cur->ne[1];
    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

    ggml_tensor * logits = nullptr;

    if (probs_in == nullptr) {
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
        cb(logits, "ffn_moe_logits", il);
    } else {
        logits = probs_in;
    }

    if (gate_inp_b) {
        logits = ggml_add(ctx0, logits, gate_inp_b);
        cb(logits, "ffn_moe_logits_biased", il);
    }

    ggml_tensor * probs = nullptr;
    switch (gating_op) {
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
            {
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
            {
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
            {
                probs = logits; // [n_expert, n_tokens]
            } break;
        default:
            GGML_ABORT("fatal error");
    }
    cb(probs, "ffn_moe_probs", il);

    // add experts selection bias - introduced in DeepSeek V3
    // leave probs unbiased as it's later used to get expert weights
    ggml_tensor * selection_probs = probs;
    if (exp_probs_b != nullptr) {
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
    if (arch == LLM_ARCH_LLAMA4) {
        selection_probs = logits;
    }

    if (arch == LLM_ARCH_GROVEMOE) {
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select top n_group_used expert groups
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

        // organize experts into n_expert_groups
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

        ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

        // get top n_group_used expert groups
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

        ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
        cb(expert_groups, "ffn_moe_group_topk", il);

        // mask out the other groups
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
        selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_masked", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
    cb(selected_experts, "ffn_moe_topk", il);

    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
        // TODO: Use scalar div instead when/if implemented
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
    } else {
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
    }

    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);

    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
        cb(weights, "ffn_moe_weights_softmax", il);
    }

    if (norm_w) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

        // Avoid division by zero, clamp to smallest number representable by F16
        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);

        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
    }
    if (scale_w) {
        weights = ggml_scale(ctx0, weights, w_scale);
        cb(weights, "ffn_moe_weights_scaled", il);
    }

    // call early so that topk-moe can be used
    ggml_build_forward_expand(gf, weights);

    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

    if (weight_before_ffn) {
        // repeat cur to [n_embd, n_expert_used, n_tokens]
        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
        cur = ggml_mul(ctx0, repeated, weights);
        cb(cur, "ffn_moe_weighted", il);
    }

    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
    cb(up, "ffn_moe_up", il);

    if (up_exps_b) {
        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
        cb(up, "ffn_moe_up_biased", il);
    }

    ggml_tensor * experts = nullptr;
    if (gate_exps) {
        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
        cb(cur, "ffn_moe_gate", il);
    } else {
        cur = up;
    }

    if (gate_exps_b) {
        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
        cb(cur, "ffn_moe_gate_biased", il);
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate_exps) {
                // Step35: per-layer clamp for routed experts
                if (arch == LLM_ARCH_STEP35 && il >= 0) {
                    const float limit = hparams.swiglu_clamp_exp[il];
                    constexpr float eps = 1e-6f;
                    if (limit > eps) {
                        ggml_tensor * gate_act = ggml_silu(ctx0, cur);
                        cb(gate_act, "ffn_moe_silu", il);
                        gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
                        cb(gate_act, "ffn_moe_silu_clamped", il);

                        up = ggml_clamp(ctx0, up, -limit, limit);
                        cb(up, "ffn_moe_up_clamped", il);

                        cur = ggml_mul(ctx0, gate_act, up);
                        cb(cur, "ffn_moe_swiglu_limited", il);
                        break;
                    }
                }

                cur = ggml_swiglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_swiglu", il);
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_moe_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate_exps) {
                cur = ggml_geglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_geglu", il);
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_moe_gelu", il);
            } break;
        case LLM_FFN_SWIGLU_OAI_MOE:
            {
                // TODO: move to hparams?
                constexpr float alpha = 1.702f;
                constexpr float limit = 7.0f;
                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
                cb(cur, "ffn_moe_swiglu_oai", il);
            } break;
        case LLM_FFN_RELU:
            if (gate_exps) {
                cur = ggml_reglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_reglu", il);
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_moe_relu", il);
            } break;
        case LLM_FFN_RELU_SQR:
            if (gate_exps) {
                // TODO: add support for gated squared relu
                GGML_ABORT("fatal error: gated squared relu not implemented");
            } else {
                cur = ggml_relu(ctx0, cur);
                cur = ggml_sqr(ctx0, cur);
                cb(cur, "ffn_moe_relu_sqr", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
    cb(experts, "ffn_moe_down", il);

    if (down_exps_b) {
        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
        cb(experts, "ffn_moe_down_biased", il);
    }

    if (!weight_before_ffn) {
        experts = ggml_mul(ctx0, experts, weights);
        cb(experts, "ffn_moe_weighted", il);
    }

    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

    assert(n_expert_used > 0);

    // order the views before the adds
    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);

        ggml_build_forward_expand(gf, cur_experts[i]);
    }

    // aggregate experts
    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
    //       to avoid potentially a large number of add nodes during warmup
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
    ggml_tensor * moe_out = cur_experts[0];

    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
    }

    if (hparams.n_expert_used == 1) {
        // avoid returning a non-contiguous tensor
        moe_out = ggml_cont(ctx0, moe_out);
    }

    cb(moe_out, "ffn_moe_out", il);

    return moe_out;
}

// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    const int64_t n_embd_inp = hparams.n_embd_inp();
    const int64_t n_embd     = hparams.n_embd;

    assert(n_embd_inp >= n_embd);

    auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);

    inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
    cb(inp->tokens, "inp_tokens", -1);
    ggml_set_input(inp->tokens);
    res->t_inp_tokens = inp->tokens;

    inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
    cb(inp->embd, "inp_embd", -1);
    ggml_set_input(inp->embd);

    // select one of the 2 inputs, based on the batch contents
    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
    std::array<ggml_tensor *, 2> inps;

    // token embeddings path (ubatch.token != nullptr)
    {
        auto & cur = inps[0];

        cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : *loras) {
            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
                        ctx0, lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp->tokens)
                        ), scale);

            cur = ggml_add(ctx0, cur, inpL_delta);
        }

        if (n_embd_inp != n_embd) {
            cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
        }
    }

    // vector embeddings path (ubatch.embd != nullptr)
    {
        auto & cur = inps[1];

        cur = inp->embd;
    }

    assert(ggml_are_same_shape (inps[0], inps[1]));
    assert(ggml_are_same_stride(inps[0], inps[1]));

    ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);

    if (n_embd_inp != n_embd) {
        cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
    }

    res->t_inp_embd = cur;

    // For Granite architecture
    if (hparams.f_embedding_scale != 0.0f) {
        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
    }

    cb(cur, "embd", -1);

    res->add_input(std::move(inp));

    // make sure the produced embeddings are immediately materialized in the ggml graph
    // ref: https://github.com/ggml-org/llama.cpp/pull/18599
    ggml_build_forward_expand(gf, cur);

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos() const {
    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

    auto & cur = inp->pos;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

    auto & cur = inp->attn_scale;

1550    // this needs to be 1x1xN for broadcasting
1551    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1552    ggml_set_input(cur);
1553
1554    res->add_input(std::move(inp));
1555
1556    return cur;
1557}
1558
1559ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1560    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1561    //       but this would make the graph topology depend on the number of output tokens, which can interfere with
1562    //       features that require constant topology, such as pipeline parallelism
1563    //       ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1564    //if (n_outputs < n_tokens) {
1565    //    return nullptr;
1566    //}
1567
1568    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
1569
1570    auto & cur = inp->out_ids;
1571
1572    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
1573    ggml_set_input(cur);
1574
1575    res->add_input(std::move(inp));
1576
1577    return cur;
1578}
1579
1580ggml_tensor * llm_graph_context::build_inp_mean() const {
1581    auto inp = std::make_unique<llm_graph_input_mean>(cparams);
1582
1583    auto & cur = inp->mean;
1584
1585    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
1586    ggml_set_input(cur);
1587
1588    res->add_input(std::move(inp));
1589
1590    return cur;
1591}
1592
1593ggml_tensor * llm_graph_context::build_inp_cls() const {
1594    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
1595
1596    auto & cur = inp->cls;
1597
1598    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
1599    ggml_set_input(cur);
1600
1601    res->add_input(std::move(inp));
1602
1603    return cur;
1604}
1605
1606ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1607    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1608
1609    auto & cur = inp->cross_embd;
1610
1611    // if we have the output embeddings from the encoder, use them directly
1612    // TODO: needs more work to be correct, for now just use the tensor shape
1613    //if (cross->t_embd) {
1614    //    cur = ggml_view_tensor(ctx0, cross->t_embd);
1615
1616    //    return cur;
1617    //}
1618
1619    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1620    const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc  : hparams.n_ctx_train;
1621
1622    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1623    ggml_set_input(cur);
1624
1625    res->add_input(std::move(inp));
1626
1627    return cur;
1628}
1629
1630ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1631    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
1632
1633    auto & cur = inp->pos_bucket;
1634
1635    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
1636    ggml_set_input(cur);
1637
1638    res->add_input(std::move(inp));
1639
1640    return cur;
1641}
1642
1643ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1644    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1645
1646    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1647
1648    const auto n_kv = mctx_cur->get_n_kv();
1649
1650    auto & cur = inp->pos_bucket;
1651
1652    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
1653    ggml_set_input(cur);
1654
1655    res->add_input(std::move(inp));
1656
1657    return cur;
1658}
1659
1660ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
1661    ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
1662    cb(pos_bucket_1d, "pos_bucket_1d", -1);
1663
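        // fetch the per-head bias for each position-pair bucket, then rearrange
        // it to match the layout of kq, to which it is later added as kq_b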
1664    ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
1665
1666    pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
1667    pos_bias = ggml_permute   (ctx0, pos_bias, 2, 0, 1, 3);
1668    pos_bias = ggml_cont      (ctx0, pos_bias);
1669
1670    cb(pos_bias, "pos_bias", -1);
1671
1672    return pos_bias;
1673}
1674
1675ggml_tensor * llm_graph_context::build_attn_mha(
1676         ggml_tensor * q,
1677         ggml_tensor * k,
1678         ggml_tensor * v,
1679         ggml_tensor * kq_b,
1680         ggml_tensor * kq_mask,
1681         ggml_tensor * sinks,
1682         ggml_tensor * v_mla,
1683               float   kq_scale,
1684                 int   il) const {
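        // a larger stride along dim 1 than along dim 2 means V is stored transposed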
1685    const bool v_trans = v->nb[1] > v->nb[2];
1686
1687    // split the batch into streams if needed
1688    const auto n_stream = k->ne[3];
1689
1690    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1691
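        // permute [n_embd_head, n_head, n_tokens, n_stream] -> [n_embd_head, n_tokens, n_head, n_stream]
        // (a no-op on the data) so that the mat-muls below batch over heads and streams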
1692    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1693    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1694    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1695
1696    ggml_tensor * cur;
1697
1698    if (cparams.flash_attn && kq_b == nullptr) {
1699        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1700
1701        if (v_trans) {
1702            v = ggml_transpose(ctx0, v);
1703        }
1704
1705        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
1706        if (k->type == GGML_TYPE_F32) {
1707            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
1708        }
1709
1710        if (v->type == GGML_TYPE_F32) {
1711            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
1712        }
1713
1714        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1715                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1716        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1717
1718        ggml_flash_attn_ext_add_sinks(cur, sinks);
1719        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1720
1721        if (v_mla) {
1722#if 0
1723            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1724            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1725            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1726            cur = ggml_mul_mat(ctx0, v_mla, cur);
1727#else
1728            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1729            // The permutations are noops and only change how the tensor data is interpreted.
1730            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1731            cur = ggml_mul_mat(ctx0, v_mla, cur);
1732            cb(cur, "fattn_mla", il);
1733            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1734            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1735#endif
1736        }
1737
1738        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1739    } else {
1740        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1741        cb(kq, "kq", il);
1742
1743        // note: this op tends to require high floating point range
1744        //       while for some models F16 is enough, for others it is not, so we default to F32 here
1745        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
1746
1747        if (arch == LLM_ARCH_GROK) {
1748            // Grok applies the following transforms
1749            // before the softmax below:
1750            //   1. multiply by attn_output_multiplier
1751            //   2. soft-cap the logits:
1752            //      kq = 30 * tanh(kq / 30)
1753
1754            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1755            cb(kq, "kq_tanh", il);
1756            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1757            cb(kq, "kq_scaled", il);
1758        }
1759
1760        if (hparams.attn_soft_cap) {
1761            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1762            cb(kq, "kq_scaled_1", il);
1763            kq = ggml_tanh (ctx0, kq);
1764            cb(kq, "kq_tanh", il);
1765            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1766            cb(kq, "kq_scaled_2", il);
1767        }
1768
1769        if (kq_b) {
1770            kq = ggml_add(ctx0, kq, kq_b);
1771            cb(kq, "kq_plus_kq_b", il);
1772        }
1773
1774        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1775        ggml_soft_max_add_sinks(kq, sinks);
1776        cb(kq, "kq_soft_max", il);
1777
1778        if (!v_trans) {
1779            // note: avoid this branch
1780            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1781            cb(v, "v_cont", il);
1782        }
1783
1784        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1785        cb(kqv, "kqv", il);
1786
1787        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1788        if (v_mla) {
1789            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1790            cb(kqv, "kqv_mla", il);
1791        }
1792
1793        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1794
1795        // recombine streams
1796        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1797
1798        if (!cparams.offload_kqv) {
1799            // all nodes between the KV store and the attention output are run on the CPU
1800            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1801        }
1802    }
1803
1804    ggml_build_forward_expand(gf, cur);
1805
1806    return cur;
1807}
1808
1809llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1810    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1811
1812    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1813    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1814    ggml_set_input(inp->self_kq_mask);
1815
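        // the flash-attention kernels expect an F16 mask, so pre-cast the F32 input mask here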
1816    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1817
1818    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1819        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1820        ggml_set_input(inp->self_kq_mask_swa);
1821
1822        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1823    } else {
1824        inp->self_kq_mask_swa     = nullptr;
1825        inp->self_kq_mask_swa_cnv = nullptr;
1826    }
1827
1828    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1829}
1830
1831ggml_tensor * llm_graph_context::build_attn(
1832        llm_graph_input_attn_no_cache * inp,
1833        ggml_tensor * wo,
1834        ggml_tensor * wo_b,
1835        ggml_tensor * q_cur,
1836        ggml_tensor * k_cur,
1837        ggml_tensor * v_cur,
1838        ggml_tensor * kq_b,
1839        ggml_tensor * sinks,
1840        ggml_tensor * v_mla,
1841            float     kq_scale,
1842            int       il) const {
1843    GGML_UNUSED(n_tokens);
1844
1845    // these nodes are added to the graph together so that they are not reordered
1846    // by doing so, the number of splits in the graph is reduced
1847    ggml_build_forward_expand(gf, q_cur);
1848    ggml_build_forward_expand(gf, k_cur);
1849    ggml_build_forward_expand(gf, v_cur);
1850
1851    const bool is_swa = hparams.is_swa(il);
1852
1853    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1854
1855    // [TAG_NO_CACHE_PAD]
1856    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1857    //       but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1858    //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1859
1860    ggml_tensor * q = q_cur;
1861    ggml_tensor * k = k_cur;
1862    ggml_tensor * v = v_cur;
1863
1864    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1865    cb(cur, "kqv_out", il);
1866
1867    if (wo) {
1868        cur = build_lora_mm(wo, cur);
1869    }
1870
1871    if (wo_b) {
1872        //cb(cur, "kqv_wo", il);
1873        cur = ggml_add(ctx0, cur, wo_b);
1874    }
1878
1879    return cur;
1880}
1881
1882static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1883           ggml_context * ctx0,
1884     const llama_ubatch & ubatch,
1885    const llama_hparams & hparams,
1886    const llama_cparams & cparams,
1887    const llama_kv_cache_context * mctx_cur) {
1888
1889    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1890
1891    {
1892        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1893
1894        const auto n_kv     = mctx_cur->get_n_kv();
1895        const auto n_tokens = ubatch.n_tokens;
1896        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
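            // a unified KV cache keeps a single stream shared by all sequences;
            // otherwise each unique sequence in the ubatch gets its own stream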
1897
1898        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1899        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1900
1901        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1902        ggml_set_input(inp->self_kq_mask);
1903
1904        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1905    }
1906
1907    return inp;
1908}
1909
1910llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1911    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1912
1913    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1914
1915    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1916}
1917
1918ggml_tensor * llm_graph_context::build_attn(
1919        llm_graph_input_attn_kv * inp,
1920        ggml_tensor * wo,
1921        ggml_tensor * wo_b,
1922        ggml_tensor * q_cur,
1923        ggml_tensor * k_cur,
1924        ggml_tensor * v_cur,
1925        ggml_tensor * kq_b,
1926        ggml_tensor * sinks,
1927        ggml_tensor * v_mla, // TODO: remove
1928            float     kq_scale,
1929            int       il) const {
1930    GGML_ASSERT(v_mla == nullptr);
1931
1932    // these nodes are added to the graph together so that they are not reordered
1933    // by doing so, the number of splits in the graph is reduced
1934    // expand k later to enable rope fusion which directly writes into k-v cache
1935    ggml_build_forward_expand(gf, q_cur);
1936    ggml_build_forward_expand(gf, v_cur);
1937    ggml_build_forward_expand(gf, k_cur);
1938
1939    const auto * mctx_cur = inp->mctx;
1940
1941    // store to KV cache
1942    {
1943        const auto & k_idxs = inp->get_k_idxs();
1944        const auto & v_idxs = inp->get_v_idxs();
1945
1946        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1947        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1948    }
1949
1950    const auto & kq_mask = inp->get_kq_mask();
1951
1952    ggml_tensor * q = q_cur;
1953    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1954    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1955
1956    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1957    cb(cur, "kqv_out", il);
1958
1959    if (wo) {
1960        cur = build_lora_mm(wo, cur);
1961        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1962            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1963            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1964        }
1965    }
1966
1967    if (wo_b) {
1968        cur = ggml_add(ctx0, cur, wo_b);
1969    }
1970
1971    return cur;
1972}
1973
1974static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
1975           ggml_context * ctx0,
1976     const llama_ubatch & ubatch,
1977    const llama_hparams & hparams,
1978    const llama_cparams & cparams,
1979    const llama_kv_cache_context * mctx_cur) {
1980
1981    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
1982
1983    {
1984        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1985
1986        const auto n_kv     = mctx_cur->get_n_kv();
1987        const auto n_tokens = ubatch.n_tokens;
1988        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1989
1990        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1991
1992        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1993        ggml_set_input(inp->self_kq_mask);
1994
1995        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1996    }
1997
1998    return inp;
1999}
2000
2001llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2002    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2003
2004    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2005
2006    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2007}
2008
2009ggml_tensor * llm_graph_context::build_attn(
2010        llm_graph_input_attn_k * inp,
2011        ggml_tensor * wo,
2012        ggml_tensor * wo_b,
2013        ggml_tensor * q_cur,
2014        ggml_tensor * k_cur,
2015        ggml_tensor * v_cur,
2016        ggml_tensor * kq_b,
2017        ggml_tensor * sinks,
2018        ggml_tensor * v_mla,
2019            float     kq_scale,
2020            int       il) const {
2021    // these nodes are added to the graph together so that they are not reordered
2022    // by doing so, the number of splits in the graph is reduced
2023    // expand k later to enable rope fusion which directly writes into k-v cache
2024    ggml_build_forward_expand(gf, q_cur);
2025    ggml_build_forward_expand(gf, v_cur);
2026    ggml_build_forward_expand(gf, k_cur);
2027
2028    const auto * mctx_cur = inp->mctx;
2029
2030    // store to KV cache
2031    {
2032        const auto & k_idxs = inp->get_k_idxs();
2033
2034        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2035    }
2036
2037    const auto & kq_mask = inp->get_kq_mask();
2038
2039    ggml_tensor * q = q_cur;
2040    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
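        // there is no separate V cache in this path - read V as a view into the K
        // cache, taking the first v_cur->ne[0] elements of each cached K vector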
2041    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2042
2043    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2044    cb(cur, "kqv_out", il);
2045
2046    if (wo) {
2047        cur = build_lora_mm(wo, cur);
2048        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
2049            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2050            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2051        }
2052    }
2053
2054    if (wo_b) {
2055        cur = ggml_add(ctx0, cur, wo_b);
2056    }
2057
2058    return cur;
2059}
2060
2061ggml_tensor * llm_graph_context::build_attn(
2062        llm_graph_input_attn_kv_iswa * inp,
2063        ggml_tensor * wo,
2064        ggml_tensor * wo_b,
2065        ggml_tensor * q_cur,
2066        ggml_tensor * k_cur,
2067        ggml_tensor * v_cur,
2068        ggml_tensor * kq_b,
2069        ggml_tensor * sinks,
2070        ggml_tensor * v_mla,
2071            float     kq_scale,
2072            int       il) const {
2073    // these nodes are added to the graph together so that they are not reordered
2074    // by doing so, the number of splits in the graph is reduced
2075    ggml_build_forward_expand(gf, q_cur);
2076
2077    if (k_cur) {
2078        ggml_build_forward_expand(gf, k_cur);
2079    }
2080
2081    if (v_cur) {
2082        ggml_build_forward_expand(gf, v_cur);
2083    }
2084
2085    const auto * mctx_iswa = inp->mctx;
2086
2087    const bool is_swa = hparams.is_swa(il);
2088
2089    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
2090
2091    // optionally store to KV cache
2092    if (k_cur) {
2093        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
2094
2095        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2096    }
2097
2098    if (v_cur) {
2099        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
2100
2101        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
2102    }
2103
2104    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
2105
2106    ggml_tensor * q = q_cur;
2107    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2108    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
2109
2110    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2111    cb(cur, "kqv_out", il);
2112
2113    if (wo) {
2114        cur = build_lora_mm(wo, cur);
2115    }
2116
2117    if (wo_b) {
2118        //cb(cur, "kqv_wo", il);
2119        cur = ggml_add(ctx0, cur, wo_b);
2120    }
2124
2125    return cur;
2126}
2127
2128llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
2129    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
2130
2131    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
2132
2133    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
2134    ggml_set_input(inp->cross_kq_mask);
2135
2136    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
2137
2138    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
2139}
2140
2141ggml_tensor * llm_graph_context::build_attn(
2142        llm_graph_input_attn_cross * inp,
2143        ggml_tensor * wo,
2144        ggml_tensor * wo_b,
2145        ggml_tensor * q_cur,
2146        ggml_tensor * k_cur,
2147        ggml_tensor * v_cur,
2148        ggml_tensor * kq_b,
2149        ggml_tensor * sinks,
2150        ggml_tensor * v_mla,
2151            float     kq_scale,
2152            int       il) const {
2153    // these nodes are added to the graph together so that they are not reordered
2154    // by doing so, the number of splits in the graph is reduced
2155    ggml_build_forward_expand(gf, q_cur);
2156    ggml_build_forward_expand(gf, k_cur);
2157    ggml_build_forward_expand(gf, v_cur);
2158
2159    const auto & kq_mask = inp->get_kq_mask_cross();
2160
2161    ggml_tensor * q = q_cur;
2162    ggml_tensor * k = k_cur;
2163    ggml_tensor * v = v_cur;
2164
2165    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2166    cb(cur, "kqv_out", il);
2167
2168    if (wo) {
2169        cur = build_lora_mm(wo, cur);
2170    }
2171
2172    if (wo_b) {
2173        //cb(cur, "kqv_wo", il);
2174        cur = ggml_add(ctx0, cur, wo_b);
2175    }
2179
2180    return cur;
2181}
2182
2183// TODO: maybe separate the inner implementation into a separate function,
2184//       like with the non-sliding-window equivalent, so that
2185//       build_inp_mem_hybrid_iswa() below can reuse it instead of inlining it
2186llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
2187    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
2188
2189    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
2190
2191    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
2192
2193    {
2194        const auto n_kv = mctx_cur->get_base()->get_n_kv();
2195
2196        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
2197        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
2198
2199        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2200        ggml_set_input(inp->self_kq_mask);
2201        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
2202
2203        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2204        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
2205    }
2206
2207    {
2208        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
2209
2210        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
2211
2212        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
2213        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
2214
2215        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2216        ggml_set_input(inp->self_kq_mask_swa);
2217        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
2218
2219        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
2220        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
2221    }
2222
2223    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
2224}
2225
2226ggml_tensor * llm_graph_context::build_rs(
2227        ggml_tensor * s,
2228        ggml_tensor * state_copy_main,
2229        ggml_tensor * state_copy_extra,
2230            int32_t   state_size,
2231            int32_t   n_seqs,
2232           uint32_t   n_rs,
2233           uint32_t   rs_head,
2234           uint32_t   rs_size,
2235            int32_t   rs_zero,
2236        const llm_graph_get_rows_fn & get_state_rows) const {
2237
2238    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
2239
2240    // Clear a single state, which the copies below then propagate to every other state that needs clearing.
2241    // Note that this is a no-op when the view is zero-sized.
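        // The (rs_zero >= 0) factors zero both the size and the offset of the view
        // when there is no state to clear.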
2242    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
2243    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
2244
2245    // copy states
2246    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
2247    // {state_size, rs_size} -> {state_size, n_seqs}
2248    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
2249    ggml_build_forward_expand(gf, output_states);
2250
2251    // copy extra states which won't be changed further (between n_seqs and n_rs)
2252    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
2253    ggml_build_forward_expand(gf,
2254        ggml_cpy(ctx0,
2255            states_extra,
2256            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
2257
2258    return output_states;
2259}
2260
2261static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
2262           ggml_context * ctx0,
2263     const llama_ubatch & ubatch,
2264    const llama_memory_recurrent_context * mctx_cur) {
2265
2266    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
2267
2268    const int64_t n_rs   = mctx_cur->get_n_rs();
2269    const int64_t n_seqs = ubatch.n_seqs;
2270
2271    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
2272    ggml_set_input(inp->s_copy);
2273
2274    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
2275    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
2276
2277    inp->head = mctx_cur->get_head();
2278    inp->rs_z = mctx_cur->get_rs_z();
2279
2280    return inp;
2281}
2282
2283llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
2284    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2285
2286    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2287
2288    return (llm_graph_input_rs *) res->add_input(std::move(inp));
2289}
2290
2291ggml_tensor * llm_graph_context::build_rs(
2292        llm_graph_input_rs * inp,
2293        ggml_tensor * s,
2294            int32_t   state_size,
2295            int32_t   n_seqs,
2296        const llm_graph_get_rows_fn & get_state_rows) const {
2297    const auto * kv_state = inp->mctx;
2298
2299    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2300                    kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2301                    get_state_rows);
2302}
2303
2304ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
2305    llm_graph_input_rs * inp,
2306    const llama_ubatch & ubatch,
2307                   int   il) const {
2308    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2309
2310    const auto token_shift_count = hparams.token_shift_count;
2311
2312    const int64_t n_seqs  = ubatch.n_seqs;
2313
2314    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
2315
2316    ggml_tensor * token_shift = build_rs(
2317            inp, token_shift_all,
2318            hparams.n_embd_r(), n_seqs);
2319
2320    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
2321
2322    return token_shift;
2323}
2324
2325ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
2326         ggml_tensor * token_shift,
2327  const llama_ubatch & ubatch,
2328                 int   il) const {
2329    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
2330
2331    const auto token_shift_count = hparams.token_shift_count;
2332    const auto n_embd = hparams.n_embd;
2333
2334    const int64_t n_seqs = ubatch.n_seqs;
2335
2336    const auto kv_head = mctx_cur->get_head();
2337
2338    return ggml_cpy(
2339        ctx0,
2340        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2341        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
2342    );
2343}
2344
2345llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2346    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2347
2348    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
2349    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2350
2351    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2352
2353    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2354}
2355
2356llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2357    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2358
2359    auto inp_rs   = build_rs_inp_impl    (ctx0, ubatch, mctx_cur->get_recr());
2360    auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2361
2362    auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2363
2364    return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2365}
2366
2367llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2368    const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2369
2370    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2371
2372    // build iswa attention input
2373    const auto * attn_ctx = mctx_cur->get_attn();
2374
2375    auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2376
2377    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
2378
2379    {
2380        const auto n_kv = attn_ctx->get_base()->get_n_kv();
2381
2382        inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2383        inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2384
2385        inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2386        ggml_set_input(inp_attn->self_kq_mask);
2387
2388        inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2389    }
2390
2391    {
2392        const auto n_kv = attn_ctx->get_swa()->get_n_kv();
2393
2394        inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2395        inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2396
2397        inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2398        ggml_set_input(inp_attn->self_kq_mask_swa);
2399
2400        inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2401    }
2402
2403    auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2404
2405    return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2406}
2407
2408void llm_graph_context::build_dense_out(
2409    ggml_tensor * dense_2,
2410    ggml_tensor * dense_3) const {
2411    if (!cparams.embeddings || !(dense_2 || dense_3)) {
2412        return;
2413    }
2414    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2415    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2416
2417    if (dense_2) {
2418        cur = ggml_mul_mat(ctx0, dense_2, cur);
2419    }
2420    if (dense_3) {
2421        cur = ggml_mul_mat(ctx0, dense_3, cur);
2422    }
2423    cb(cur, "result_embd_pooled", -1);
2424    res->t_embd_pooled = cur;
2425    ggml_build_forward_expand(gf, cur);
2426}
2427
2429void llm_graph_context::build_pooling(
2430        ggml_tensor * cls,
2431        ggml_tensor * cls_b,
2432        ggml_tensor * cls_out,
2433        ggml_tensor * cls_out_b) const {
2434    if (!cparams.embeddings) {
2435        return;
2436    }
2437
2438    ggml_tensor * inp = res->t_embd;
2439
2440    //// find result_norm tensor for input
2441    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
2442    //    inp = ggml_graph_node(gf, i);
2443    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
2444    //        break;
2445    //    }
2446
2447    //    inp = nullptr;
2448    //}
2449
2450    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
2451
2452    ggml_tensor * cur;
2453
2454    switch (pooling_type) {
2455        case LLAMA_POOLING_TYPE_NONE:
2456            {
2457                cur = inp;
2458            } break;
2459        case LLAMA_POOLING_TYPE_MEAN:
2460            {
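                    // inp_mean is [n_tokens, n_seqs_unq] and is expected to hold the
                    // per-sequence averaging weights, so the mat-mul below yields the
                    // mean embedding of each sequence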
2461                ggml_tensor * inp_mean = build_inp_mean();
2462                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2463            } break;
2464        case LLAMA_POOLING_TYPE_CLS:
2465        case LLAMA_POOLING_TYPE_LAST:
2466            {
2467                ggml_tensor * inp_cls = build_inp_cls();
2468                cur = ggml_get_rows(ctx0, inp, inp_cls);
2469            } break;
2470        case LLAMA_POOLING_TYPE_RANK:
2471            {
2472                ggml_tensor * inp_cls = build_inp_cls();
2473                cur = ggml_get_rows(ctx0, inp, inp_cls);
2474
2475                // classification head
2476                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
2477                if (cls) {
2478                    cur = ggml_mul_mat(ctx0, cls, cur);
2479                    if (cls_b) {
2480                        cur = ggml_add(ctx0, cur, cls_b);
2481                    }
2482                    cur = ggml_tanh(ctx0, cur);
2483                }
2484
2485                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
2486                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2487                // Single layer classification head (direct projection)
2488                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
2489                if (cls_out) {
2490                    cur = ggml_mul_mat(ctx0, cls_out, cur);
2491                    if (cls_out_b) {
2492                        cur = ggml_add(ctx0, cur, cls_out_b);
2493                    }
2494                }
2495
2496                // softmax for qwen3 reranker
2497                if (arch == LLM_ARCH_QWEN3) {
2498                    cur = ggml_soft_max(ctx0, cur);
2499                }
2500            } break;
2501        default:
2502            {
2503                GGML_ABORT("unknown pooling type");
2504            }
2505    }
2506
2507    cb(cur, "result_embd_pooled", -1);
2508    res->t_embd_pooled = cur;
2509
2510    ggml_build_forward_expand(gf, cur);
2511}
2512
2513void llm_graph_context::build_sampling() const {
2514    if (samplers.empty() || !res->t_logits) {
2515        return;
2516    }
2517
2518    std::array<ggml_tensor *, 2> outs;
2519    outs[0] = res->t_logits;
2520
2521    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2522    res->add_input(std::move(inp_sampling));
2523
2524    std::map<llama_seq_id, int32_t> seq_to_logit_row;
2525    int32_t logit_row_idx = 0;
2526
2527    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2528        if (ubatch.output[i]) {
2529            llama_seq_id seq_id = ubatch.seq_id[i][0];
2530            seq_to_logit_row[seq_id] = logit_row_idx;
2531            logit_row_idx++;
2532        }
2533    }
2534
2535    // res->t_logits contains the logits of every token that requested them (logits == 1 or output == 1)
2536    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2537
2538    // add a dummy row of logits
2539    // this trick makes the graph static, regardless of which samplers are activated
2540    // this is important in order to minimize graph reallocations
2541    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
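        // the padded row guarantees that row 0 exists even when no token in the
        // ubatch produced logits, so inactive samplers always have a row to run on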
2542
2543    for (const auto & [seq_id, sampler] : samplers) {
2544        const auto it = seq_to_logit_row.find(seq_id);
2545
2546        // inactive samplers always work on the first row
2547        const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2548        const int i_out    = it != seq_to_logit_row.end() ? 1          : 0;
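            // i_out selects between the dummy output (0) and the sampler's real
            // output (1) in the ggml_build_forward_select() calls below, keeping
            // the graph topology the same whether or not the sampler is active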
2549
2550        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2551        ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2552
2553        struct llama_sampler_data data = {
2554            /*.logits      =*/ logits_seq,
2555            /*.probs       =*/ nullptr,
2556            /*.sampled     =*/ nullptr,
2557            /*.candidates  =*/ nullptr,
2558        };
2559
2560        assert(sampler->iface->backend_apply);
2561        sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2562
2563        if (data.sampled != nullptr) {
2564            res->t_sampled[seq_id] = data.sampled;
2565            outs[1] = data.sampled;
2566            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2567        }
2568
2569        if (data.probs != nullptr) {
2570            res->t_sampled_probs[seq_id] = data.probs;
2571            outs[1] = data.probs;
2572            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2573        }
2574
2575        if (data.logits != nullptr) {
2576            res->t_sampled_logits[seq_id] = data.logits;
2577            outs[1] = data.logits;
2578            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2579        }
2580
2581        if (data.candidates != nullptr) {
2582            res->t_candidates[seq_id] = data.candidates;
2583            outs[1] = data.candidates;
2584            ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2585        }
2586    }
2587
2588    // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2589    /*
2590    for (const auto & [seq_id, sampler] : samplers) {
2591        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2592            ggml_tensor * selected_token = it->second;
2593            if (selected_token != nullptr) {
2594                llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2595            }
2596        }
2597    }
2598    */
2599}
2600
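    // T5-style relative position bucketing: distances below max_exact each map to
    // their own bucket, larger distances are binned logarithmically up to max_distance
    // e.g. unidirectional with n_buckets = 32: distances 0..15 -> buckets 0..15,
    // larger distances share buckets 16..31, clamping to 31 from distance ~128 on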
2601int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
2602    // TODO move to hparams if a T5 variant appears that uses a different value
2603    const int64_t max_distance = 128;
2604
2605    if (bidirectional) {
2606        n_buckets >>= 1;
2607    }
2608
2609    const int64_t max_exact = n_buckets >> 1;
2610
2611    int32_t relative_position = x - y;
2612    int32_t relative_bucket = 0;
2613
2614    if (bidirectional) {
2615        relative_bucket += (relative_position > 0) * n_buckets;
2616        relative_position = std::abs(relative_position);
2617    } else {
2618        relative_position = -std::min<int32_t>(relative_position, 0);
2619    }
2620
2621    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
2622    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
2623    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
2624
2625    return relative_bucket;
2626}