1#include "ggml-metal-device.h"
   2
   3#include "ggml-metal-impl.h"
   4
   5#include "ggml-impl.h"
   6
   7#include <cassert>
   8#include <memory>
   9#include <string>
  10#include <unordered_map>
  11
  12struct ggml_metal_device_deleter {
  13    void operator()(ggml_metal_device_t ctx) {
  14        ggml_metal_device_free(ctx);
  15    }
  16};
  17
  18typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
  19
  20ggml_metal_device_t ggml_metal_device_get(int device) {
  21    static std::vector<ggml_metal_device_ptr> devs;
  22
  23    devs.emplace_back(ggml_metal_device_init(device));
  24
  25    return devs.back().get();
  26}
  27
  28struct ggml_metal_pipelines {
  29    std::unordered_map<std::string, ggml_metal_pipeline_t> data;
  30};
  31
  32ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
  33    ggml_metal_pipelines_t res = new ggml_metal_pipelines();
  34
  35    return res;
  36}
  37
  38void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
  39    if (!ppls) {
  40        return;
  41    }
  42
  43    for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
  44        ggml_metal_pipeline_free(it->second);
  45    }
  46
  47    delete ppls;
  48}
  49
  50void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {
  51    ppls->data[name] = pipeline;
  52}
  53
  54ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
  55    if (ppls->data.find(name) == ppls->data.end()) {
  56        return nullptr;
  57    }
  58
  59    return ppls->data[name];
  60}
  61
  62struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
  63    char base[256];
  64    char name[256];
  65
  66    const char * op_str = "undefined";
  67    switch (op) {
  68        case GGML_OP_ADD_ID: op_str = "add_id"; break;
  69        case GGML_OP_CONCAT: op_str = "concat"; break;
  70        default: GGML_ABORT("fatal error");
  71    };
  72
  73    snprintf(base, 256, "kernel_%s", op_str);
  74    snprintf(name, 256, "%s", base);
  75
  76    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
  77    if (!res.pipeline) {
  78        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
  79    }
  80
  81    return res;
  82}
  83
  84ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
  85    char base[256];
  86    char name[256];
  87
  88    snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
  89    snprintf(name, 256, "%s", base);
  90
  91    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
  92    if (!res.pipeline) {
  93        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
  94    }
  95
  96    return res;
  97}
  98
  99ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
 100    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
 101    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
 102
 103    const char * pool_str = "undefined";
 104    switch (op_pool) {
 105        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
 106        case GGML_OP_POOL_MAX: pool_str = "max"; break;
 107        default: GGML_ASSERT(false && "not implemented");
 108    };
 109
 110    char base[256];
 111    char name[256];
 112
 113    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
 114    snprintf(name, sizeof(name), "%s", base);
 115
 116    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 117    if (!res.pipeline) {
 118        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 119    }
 120
 121    return res;
 122}
 123
 124ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
 125    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
 126    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
 127
 128    const char * pool_str = "undefined";
 129    switch (op_pool) {
 130        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
 131        case GGML_OP_POOL_MAX: pool_str = "max"; break;
 132        default: GGML_ASSERT(false && "not implemented");
 133    };
 134
 135    char base[256];
 136    char name[256];
 137
 138    snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
 139    snprintf(name, 256, "%s", base);
 140
 141    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 142    if (!res.pipeline) {
 143        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 144    }
 145
 146    return res;
 147}
 148
 149ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
 150    char base[256];
 151    char name[256];
 152
 153    snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
 154    snprintf(name, 256, "%s", base);
 155
 156    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 157    if (!res.pipeline) {
 158        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 159    }
 160
 161    return res;
 162}
 163
 164ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
 165    char base[256];
 166    char name[256];
 167
 168    snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
 169    snprintf(name, 256, "%s", base);
 170
 171    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 172    if (!res.pipeline) {
 173        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 174    }
 175
 176    return res;
 177}
 178
 179ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
 180    char base[256];
 181    char name[256];
 182
 183    const int n = op->src[0]->ne[0];
 184
 185    snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
 186    snprintf(name, 256, "%s_n=%d", base, n);
 187
 188    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 189    if (!res.pipeline) {
 190        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 191    }
 192
 193    res.nsg  = 1;
 194    res.smem = 0;
 195
 196    return res;
 197}
 198
 199ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
 200    char base[256];
 201    char name[256];
 202
 203    snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
 204    snprintf(name, 256, "%s", base);
 205
 206    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 207    if (!res.pipeline) {
 208        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 209    }
 210
 211    return res;
 212}
 213
 214ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
 215    char base[256];
 216    char name[256];
 217
 218    int op_num = -1;
 219
 220    switch (op->op) {
 221        case GGML_OP_SCALE:      op_num = OP_UNARY_NUM_SCALE;      break;
 222        case GGML_OP_FILL:       op_num = OP_UNARY_NUM_FILL;       break;
 223        case GGML_OP_CLAMP:      op_num = OP_UNARY_NUM_CLAMP;      break;
 224        case GGML_OP_SQR:        op_num = OP_UNARY_NUM_SQR;        break;
 225        case GGML_OP_SQRT:       op_num = OP_UNARY_NUM_SQRT;       break;
 226        case GGML_OP_SIN:        op_num = OP_UNARY_NUM_SIN;        break;
 227        case GGML_OP_COS:        op_num = OP_UNARY_NUM_COS;        break;
 228        case GGML_OP_LOG:        op_num = OP_UNARY_NUM_LOG;        break;
 229        case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
 230        case GGML_OP_UNARY:
 231            switch (ggml_get_unary_op(op)) {
 232                case GGML_UNARY_OP_TANH:        op_num = OP_UNARY_NUM_TANH;        break;
 233                case GGML_UNARY_OP_RELU:        op_num = OP_UNARY_NUM_RELU;        break;
 234                case GGML_UNARY_OP_SIGMOID:     op_num = OP_UNARY_NUM_SIGMOID;     break;
 235                case GGML_UNARY_OP_GELU:        op_num = OP_UNARY_NUM_GELU;        break;
 236                case GGML_UNARY_OP_GELU_ERF:    op_num = OP_UNARY_NUM_GELU_ERF;    break;
 237                case GGML_UNARY_OP_GELU_QUICK:  op_num = OP_UNARY_NUM_GELU_QUICK;  break;
 238                case GGML_UNARY_OP_SILU:        op_num = OP_UNARY_NUM_SILU;        break;
 239                case GGML_UNARY_OP_ELU:         op_num = OP_UNARY_NUM_ELU;         break;
 240                case GGML_UNARY_OP_NEG:         op_num = OP_UNARY_NUM_NEG;         break;
 241                case GGML_UNARY_OP_ABS:         op_num = OP_UNARY_NUM_ABS;         break;
 242                case GGML_UNARY_OP_SGN:         op_num = OP_UNARY_NUM_SGN;         break;
 243                case GGML_UNARY_OP_STEP:        op_num = OP_UNARY_NUM_STEP;        break;
 244                case GGML_UNARY_OP_HARDSWISH:   op_num = OP_UNARY_NUM_HARDSWISH;   break;
 245                case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
 246                case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;
 247                case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;
 248                case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;
 249                default: GGML_ABORT("fatal error");
 250            } break;
 251        default: GGML_ABORT("fatal error");
 252    };
 253
 254    const char * t0_str = ggml_type_name(op->src[0]->type);
 255    const char * t_str  = ggml_type_name(op->type);
 256
 257    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
 258    const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
 259
 260    snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
 261    snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
 262
 263    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 264    if (!res.pipeline) {
 265        ggml_metal_cv_t cv = ggml_metal_cv_init();
 266
 267        ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
 268        ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
 269
 270        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 271
 272        ggml_metal_cv_free(cv);
 273    }
 274
 275    res.c4  = is_c4;
 276    res.cnt = is_cnt;
 277
 278    return res;
 279}
 280
 281ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
 282    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
 283
 284    char base[256];
 285    char name[256];
 286
 287    const char * op_str = "undefined";
 288    switch (op->op) {
 289        case GGML_OP_GLU:
 290            switch (ggml_get_glu_op(op)) {
 291                case GGML_GLU_OP_REGLU:        op_str = "reglu";        break;
 292                case GGML_GLU_OP_GEGLU:        op_str = "geglu";        break;
 293                case GGML_GLU_OP_SWIGLU:       op_str = "swiglu";       break;
 294                case GGML_GLU_OP_SWIGLU_OAI:   op_str = "swiglu_oai";   break;
 295                case GGML_GLU_OP_GEGLU_ERF:    op_str = "geglu_erf";    break;
 296                case GGML_GLU_OP_GEGLU_QUICK:  op_str = "geglu_quick";  break;
 297                default: GGML_ABORT("fatal error");
 298            } break;
 299        default: GGML_ABORT("fatal error");
 300    };
 301
 302    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
 303    snprintf(name, 256, "%s", base);
 304
 305    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 306    if (!res.pipeline) {
 307        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 308    }
 309
 310    return res;
 311}
 312
 313ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
 314    assert(op->op == GGML_OP_SUM);
 315
 316    char base[256];
 317    char name[256];
 318
 319    snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
 320    snprintf(name, 256, "%s", base);
 321
 322    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 323    if (!res.pipeline) {
 324        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 325    }
 326
 327    return res;
 328}
 329
 330ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
 331    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
 332
 333    char base[256];
 334    char name[256];
 335
 336    const char * op_str = "undefined";
 337    switch (op->op) {
 338        case GGML_OP_SUM_ROWS:
 339            op_str = "sum_rows"; break;
 340        case GGML_OP_MEAN:
 341            op_str = "mean"; break;
 342        default: GGML_ABORT("fatal error");
 343    };
 344
 345    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
 346
 347    snprintf(name, 256, "%s", base);
 348
 349    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 350    if (!res.pipeline) {
 351        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 352    }
 353
 354    res.smem = 32*sizeof(float);
 355
 356    return res;
 357}
 358
 359ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
 360    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
 361
 362    char base[256];
 363    char name[256];
 364
 365    snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
 366    snprintf(name, 256, "%s", base);
 367
 368    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 369    if (!res.pipeline) {
 370        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 371    }
 372
 373    return res;
 374}
 375
 376ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
 377    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
 378
 379    char base[256];
 380    char name[256];
 381
 382    snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
 383    snprintf(name, 256, "%s", base);
 384
 385    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 386    if (!res.pipeline) {
 387        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 388    }
 389
 390    return res;
 391}
 392
 393ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
 394    GGML_ASSERT(op->op == GGML_OP_TRI);
 395    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
 396
 397    char base[256];
 398    char name[256];
 399
 400    const char * op_str = "tri";
 401    const int ttype = op->op_params[0];
 402
 403    snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
 404
 405    snprintf(name, 256, "%s", base);
 406
 407    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 408    if (!res.pipeline) {
 409        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 410    }
 411
 412    return res;
 413}
 414
 415ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
 416    GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
 417
 418    char base[256];
 419    char name[256];
 420
 421    const char * suffix = "";
 422
 423    if (op->src[0]->ne[0] % 4 == 0) {
 424        suffix = "_4";
 425    }
 426
 427    const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32;
 428
 429    snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
 430    snprintf(name, 256, "%s", base);
 431
 432    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 433    if (!res.pipeline) {
 434        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 435    }
 436
 437    res.smem = 32*sizeof(float);
 438
 439    return res;
 440}
 441
 442ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
 443    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
 444    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
 445
 446    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
 447    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
 448
 449    char base[256];
 450    char name[256];
 451
 452    const char * suffix = "";
 453
 454    if (op->src[1]->ne[0] % 4 == 0) {
 455        suffix = "_4";
 456    }
 457
 458    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
 459    snprintf(name, 256, "%s", base);
 460
 461    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 462    if (!res.pipeline) {
 463        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 464    }
 465
 466    return res;
 467}
 468
 469ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
 470    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
 471    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
 472
 473    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
 474    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
 475
 476    char base[256];
 477    char name[256];
 478
 479    const char * suffix = "";
 480    if (op->src[1]->ne[0] % 4 == 0) {
 481        suffix = "_4";
 482    }
 483
 484    snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
 485    snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
 486
 487    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 488    if (!res.pipeline) {
 489        ggml_metal_cv_t cv = ggml_metal_cv_init();
 490
 491        ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
 492
 493        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 494
 495        ggml_metal_cv_free(cv);
 496    }
 497
 498    return res;
 499}
 500
 501ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {
 502    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 503
 504    char base[256];
 505    char name[256];
 506
 507    const int nsg = (ne00 + 31)/32;
 508
 509    snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
 510    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 511
 512    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 513    if (!res.pipeline) {
 514        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 515    }
 516
 517    // Shared memory layout:
 518    // - sgptg * NW floats for partial sums (nsg * 32)
 519    // - sgptg floats for shared_x_dt (nsg)
 520    // - sgptg floats for shared_dA (nsg)
 521    // Total: nsg * (32 + 2) floats
 522    res.smem = (32 + 2)*sizeof(float)*nsg;
 523
 524    return res;
 525}
 526
 527ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
 528    char base[256];
 529    char name[256];
 530
 531    const int64_t C = op->ne[0];
 532    const int64_t H = op->src[0]->ne[1];
 533
 534    switch (op->op) {
 535        case GGML_OP_RWKV_WKV6:
 536            {
 537                GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
 538                GGML_ASSERT(C % H == 0);
 539                GGML_ASSERT(C / H == 64);
 540
 541                snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type));
 542            } break;
 543        case GGML_OP_RWKV_WKV7:
 544            {
 545                GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32);
 546                GGML_ASSERT(C % H == 0);
 547                GGML_ASSERT(C / H == 64);
 548
 549                snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type));
 550            } break;
 551        default:
 552            GGML_ABORT("fatal error");
 553    }
 554
 555    snprintf(name, 256, "%s", base);
 556
 557    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 558    if (!res.pipeline) {
 559        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 560    }
 561
 562    return res;
 563}
 564
 565ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
 566    char base[256];
 567    char name[256];
 568
 569    const int nsg = 8;
 570    const int n   = op->src[1]->ne[1];
 571    const int k   = op->src[1]->ne[0];
 572
 573    snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
 574    snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
 575
 576    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 577    if (!res.pipeline) {
 578        ggml_metal_cv_t cv = ggml_metal_cv_init();
 579
 580        ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
 581        ggml_metal_cv_set_int16(cv, n,   FC_SOLVE_TRI + 1);
 582        ggml_metal_cv_set_int16(cv, k,   FC_SOLVE_TRI + 2);
 583
 584        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 585
 586        ggml_metal_cv_free(cv);
 587    }
 588
 589    res.nsg  = nsg;
 590    res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
 591
 592    return res;
 593}
 594
 595ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
 596    char base[256];
 597    char name[256];
 598
 599    snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
 600    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
 601
 602    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 603    if (!res.pipeline) {
 604        ggml_metal_cv_t cv = ggml_metal_cv_init();
 605
 606        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
 607        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
 608
 609        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 610
 611        ggml_metal_cv_free(cv);
 612    }
 613
 614    return res;
 615}
 616
 617ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
 618    char base[256];
 619    char name[256];
 620
 621    const ggml_type tsrc0 = op->src[0]->type;
 622    const ggml_type tsrc1 = op->src[1]->type;
 623
 624    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
 625    const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
 626
 627    snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
 628    snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
 629
 630    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 631    if (!res.pipeline) {
 632        ggml_metal_cv_t cv = ggml_metal_cv_init();
 633
 634        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
 635        ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
 636
 637        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 638
 639        ggml_metal_cv_free(cv);
 640    }
 641
 642    // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
 643    res.smem = bc_out ? 8192 : 4096 + 2048;
 644
 645    return res;
 646}
 647
 648ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
 649    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 650    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 651
 652    char base[256];
 653    char name[256];
 654
 655    int nsg = 0; // number of simdgroups
 656    int nr0 = 0; // number of src0 rows per simdgroup
 657    int nr1 = 1; // number of src1 rows per threadgroup
 658
 659    size_t smem = 0; // shared memory
 660
 661    const ggml_type tsrc0 = op->src[0]->type;
 662    const ggml_type tsrc1 = op->src[1]->type;
 663
 664    const char * suffix = "";
 665
 666    // use custom matrix x vector kernel
 667    switch (tsrc0) {
 668        case GGML_TYPE_F32:
 669        case GGML_TYPE_F16:
 670        case GGML_TYPE_BF16:
 671            {
 672                if (ne00 < 32) {
 673                    nsg = 1;
 674                    nr0 = 32;
 675                    nr1 = 1;
 676                    suffix = "_short";
 677                } else {
 678                    nsg = std::min(4, (ne00 + 127) / 128);
 679                    nr0 = 2;
 680                    nr1 = 1;
 681                    smem = 32*sizeof(float)*nr0;
 682                    suffix = ne00 % 4 == 0 ? "_4" : "";
 683                }
 684            } break;
 685        case GGML_TYPE_Q4_0:
 686            {
 687                nsg = N_SG_Q4_0;
 688                nr0 = N_R0_Q4_0;
 689            } break;
 690        case GGML_TYPE_Q4_1:
 691            {
 692                nsg = N_SG_Q4_1;
 693                nr0 = N_R0_Q4_1;
 694            } break;
 695        case GGML_TYPE_Q5_0:
 696            {
 697                nsg = N_SG_Q5_0;
 698                nr0 = N_R0_Q5_0;
 699            } break;
 700        case GGML_TYPE_Q5_1:
 701            {
 702                nsg = N_SG_Q5_1;
 703                nr0 = N_R0_Q5_1;
 704            } break;
 705        case GGML_TYPE_Q8_0:
 706            {
 707                nsg = N_SG_Q8_0;
 708                nr0 = N_R0_Q8_0;
 709                smem = 32*sizeof(float)*N_R0_Q8_0;
 710            } break;
 711        case GGML_TYPE_MXFP4:
 712            {
 713                nsg = N_SG_MXFP4;
 714                nr0 = N_R0_MXFP4;
 715                smem = 32*sizeof(float);
 716            } break;
 717        case GGML_TYPE_Q2_K:
 718            {
 719                nsg = N_SG_Q2_K;
 720                nr0 = N_R0_Q2_K;
 721            } break;
 722        case GGML_TYPE_Q3_K:
 723            {
 724                nsg = N_SG_Q3_K;
 725                nr0 = N_R0_Q3_K;
 726            } break;
 727        case GGML_TYPE_Q4_K:
 728            {
 729                nsg = N_SG_Q4_K;
 730                nr0 = N_R0_Q4_K;
 731            } break;
 732        case GGML_TYPE_Q5_K:
 733            {
 734                nsg = N_SG_Q5_K;
 735                nr0 = N_R0_Q5_K;
 736            } break;
 737        case GGML_TYPE_Q6_K:
 738            {
 739                nsg = N_SG_Q6_K;
 740                nr0 = N_R0_Q6_K;
 741            } break;
 742        case GGML_TYPE_IQ2_XXS:
 743            {
 744                nsg = N_SG_IQ2_XXS;
 745                nr0 = N_R0_IQ2_XXS;
 746                smem = 256*8+128;
 747            } break;
 748        case GGML_TYPE_IQ2_XS:
 749            {
 750                nsg = N_SG_IQ2_XS;
 751                nr0 = N_R0_IQ2_XS;
 752                smem = 512*8+128;
 753            } break;
 754        case GGML_TYPE_IQ3_XXS:
 755            {
 756                nsg = N_SG_IQ3_XXS;
 757                nr0 = N_R0_IQ3_XXS;
 758                smem = 256*4+128;
 759            } break;
 760        case GGML_TYPE_IQ3_S:
 761            {
 762                nsg = N_SG_IQ3_S;
 763                nr0 = N_R0_IQ3_S;
 764                smem = 512*4;
 765            } break;
 766        case GGML_TYPE_IQ2_S:
 767            {
 768                nsg = N_SG_IQ2_S;
 769                nr0 = N_R0_IQ2_S;
 770            } break;
 771        case GGML_TYPE_IQ1_S:
 772            {
 773                nsg = N_SG_IQ1_S;
 774                nr0 = N_R0_IQ1_S;
 775            } break;
 776        case GGML_TYPE_IQ1_M:
 777            {
 778                nsg = N_SG_IQ1_M;
 779                nr0 = N_R0_IQ1_M;
 780            } break;
 781        case GGML_TYPE_IQ4_NL:
 782            {
 783                nsg = N_SG_IQ4_NL;
 784                nr0 = N_R0_IQ4_NL;
 785                smem = 32*sizeof(float);
 786            } break;
 787        case GGML_TYPE_IQ4_XS:
 788            {
 789                nsg = N_SG_IQ4_XS;
 790                nr0 = N_R0_IQ4_XS;
 791                smem = 32*sizeof(float);
 792            } break;
 793        default:
 794            {
 795                GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
 796                GGML_ABORT("not implemented");
 797            }
 798    };
 799
 800    snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
 801    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 802
 803    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 804    if (!res.pipeline) {
 805        ggml_metal_cv_t cv = ggml_metal_cv_init();
 806
 807        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
 808
 809        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 810
 811        ggml_metal_cv_free(cv);
 812    }
 813
 814    res.nr0  = nr0;
 815    res.nr1  = nr1;
 816    res.nsg  = nsg;
 817    res.smem = smem;
 818
 819    return res;
 820}
 821
 822ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
 823    char base[256];
 824    char name[256];
 825
 826    snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
 827    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 828
 829    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 830    if (!res.pipeline) {
 831        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 832    }
 833
 834    res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
 835
 836    return res;
 837}
 838
 839ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
 840    char base[256];
 841    char name[256];
 842
 843    const ggml_type tsrc0 = op->src[0]->type;
 844    const ggml_type tsrc1 = op->src[1]->type;
 845
 846    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
 847
 848    snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
 849    snprintf(name, 256, "%s_bci=%d", base, bc_inp);
 850
 851    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
 852    if (!res.pipeline) {
 853        ggml_metal_cv_t cv = ggml_metal_cv_init();
 854
 855        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
 856
 857        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 858
 859        ggml_metal_cv_free(cv);
 860    }
 861
 862    res.smem = 8192;
 863
 864    return res;
 865}
 866
 867ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
 868    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 869    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 870
 871    char base[256];
 872    char name[256];
 873
 874    int nsg = 0; // number of simdgroups
 875    int nr0 = 0; // number of src0 rows per simdgroup
 876    int nr1 = 1; // number of src1 rows per threadgroup
 877
 878    size_t smem = 0; // shared memory
 879
 880    const ggml_type tsrc0 = op->src[0]->type;
 881    const ggml_type tsrc1 = op->src[1]->type;
 882
 883    const char * suffix = "";
 884
 885        // use custom matrix x vector kernel
 886    switch (tsrc0) {
 887        case GGML_TYPE_F32:
 888        case GGML_TYPE_F16:
 889        case GGML_TYPE_BF16:
 890            {
 891                nsg = std::min(4, (ne00 + 127) / 128);
 892                nr0 = 2;
 893                nr1 = 1;
 894                smem = 32*sizeof(float)*nr0;
 895                suffix = ne00 % 4 == 0 ? "_4" : "";
 896            } break;
 897        case GGML_TYPE_Q4_0:
 898            {
 899                nsg = N_SG_Q4_0;
 900                nr0 = N_R0_Q4_0;
 901            } break;
 902        case GGML_TYPE_Q4_1:
 903            {
 904                nsg = N_SG_Q4_1;
 905                nr0 = N_R0_Q4_1;
 906            } break;
 907        case GGML_TYPE_Q5_0:
 908            {
 909                nsg = N_SG_Q5_0;
 910                nr0 = N_R0_Q5_0;
 911            } break;
 912        case GGML_TYPE_Q5_1:
 913            {
 914                nsg = N_SG_Q5_1;
 915                nr0 = N_R0_Q5_1;
 916            } break;
 917        case GGML_TYPE_Q8_0:
 918            {
 919                nsg = N_SG_Q8_0;
 920                nr0 = N_R0_Q8_0;
 921                smem = 32*sizeof(float)*N_R0_Q8_0;
 922            } break;
 923        case GGML_TYPE_MXFP4:
 924            {
 925                nsg = N_SG_MXFP4;
 926                nr0 = N_R0_MXFP4;
 927                smem = 32*sizeof(float);
 928            } break;
 929        case GGML_TYPE_Q2_K:
 930            {
 931                nsg = N_SG_Q2_K;
 932                nr0 = N_R0_Q2_K;
 933            } break;
 934        case GGML_TYPE_Q3_K:
 935            {
 936                nsg = N_SG_Q3_K;
 937                nr0 = N_R0_Q3_K;
 938            } break;
 939        case GGML_TYPE_Q4_K:
 940            {
 941                nsg = N_SG_Q4_K;
 942                nr0 = N_R0_Q4_K;
 943            } break;
 944        case GGML_TYPE_Q5_K:
 945            {
 946                nsg = N_SG_Q5_K;
 947                nr0 = N_R0_Q5_K;
 948            } break;
 949        case GGML_TYPE_Q6_K:
 950            {
 951                nsg = N_SG_Q6_K;
 952                nr0 = N_R0_Q6_K;
 953            } break;
 954        case GGML_TYPE_IQ2_XXS:
 955            {
 956                nsg = N_SG_IQ2_XXS;
 957                nr0 = N_R0_IQ2_XXS;
 958                smem = 256*8+128;
 959            } break;
 960        case GGML_TYPE_IQ2_XS:
 961            {
 962                nsg = N_SG_IQ2_XS;
 963                nr0 = N_R0_IQ2_XS;
 964                smem = 512*8+128;
 965            } break;
 966        case GGML_TYPE_IQ3_XXS:
 967            {
 968                nsg = N_SG_IQ3_XXS;
 969                nr0 = N_R0_IQ3_XXS;
 970                smem = 256*4+128;
 971            } break;
 972        case GGML_TYPE_IQ3_S:
 973            {
 974                nsg = N_SG_IQ3_S;
 975                nr0 = N_R0_IQ3_S;
 976                smem = 512*4;
 977            } break;
 978        case GGML_TYPE_IQ2_S:
 979            {
 980                nsg = N_SG_IQ2_S;
 981                nr0 = N_R0_IQ2_S;
 982            } break;
 983        case GGML_TYPE_IQ1_S:
 984            {
 985                nsg = N_SG_IQ1_S;
 986                nr0 = N_R0_IQ1_S;
 987            } break;
 988        case GGML_TYPE_IQ1_M:
 989            {
 990                nsg = N_SG_IQ1_M;
 991                nr0 = N_R0_IQ1_M;
 992            } break;
 993        case GGML_TYPE_IQ4_NL:
 994            {
 995                nsg = N_SG_IQ4_NL;
 996                nr0 = N_R0_IQ4_NL;
 997                smem = 32*sizeof(float);
 998            } break;
 999        case GGML_TYPE_IQ4_XS:
1000            {
1001                nsg = N_SG_IQ4_XS;
1002                nr0 = N_R0_IQ4_XS;
1003                smem = 32*sizeof(float);
1004            } break;
1005        default:
1006            {
1007                GGML_LOG_ERROR("Asserting on type %d\n", (int)op->src[2]->type);
1008                GGML_ABORT("not implemented");
1009            }
1010    };
1011
1012    snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
1013    snprintf(name, 256, "%s_nsg=%d", base, nsg);
1014
1015    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1016    if (!res.pipeline) {
1017        ggml_metal_cv_t cv = ggml_metal_cv_init();
1018
1019        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
1020
1021        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1022
1023        ggml_metal_cv_free(cv);
1024    }
1025
1026    res.nr0  = nr0;
1027    res.nr1  = nr1;
1028    res.nsg  = nsg;
1029    res.smem = smem;
1030
1031    return res;
1032}
1033
1034ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
1035    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
1036    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
1037    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
1038
1039    char base[256];
1040    char name[256];
1041
1042    snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
1043    snprintf(name, 256, "%s", base);
1044
1045    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1046    if (!res.pipeline) {
1047        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1048    }
1049
1050    res.smem = 32*(sizeof(float) + sizeof(int32_t));
1051
1052    return res;
1053}
1054
1055ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
1056    assert(op->op == GGML_OP_ARGSORT);
1057
1058    char base[256];
1059    char name[256];
1060
1061    ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1062
1063    const char * order_str = "undefined";
1064    switch (order) {
1065        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
1066        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1067        default: GGML_ABORT("fatal error");
1068    };
1069
1070    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1071    snprintf(name, 256, "%s", base);
1072
1073    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1074    if (!res.pipeline) {
1075        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1076    }
1077
1078    return res;
1079}
1080
1081ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1082    assert(op->op == GGML_OP_ARGSORT);
1083
1084    char base[256];
1085    char name[256];
1086
1087    ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1088
1089    const char * order_str = "undefined";
1090    switch (order) {
1091        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
1092        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1093        default: GGML_ABORT("fatal error");
1094    };
1095
1096    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1097    snprintf(name, 256, "%s", base);
1098
1099    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1100    if (!res.pipeline) {
1101        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1102    }
1103
1104    return res;
1105}
1106
1107// note: reuse the argsort kernel for top_k
1108ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1109    assert(op->op == GGML_OP_TOP_K);
1110
1111    char base[256];
1112    char name[256];
1113
1114    // note: the top_k kernel is always descending order
1115    ggml_sort_order order = GGML_SORT_ORDER_DESC;
1116
1117    const char * order_str = "undefined";
1118    switch (order) {
1119        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
1120        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1121        default: GGML_ABORT("fatal error");
1122    };
1123
1124    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1125    snprintf(name, 256, "%s", base);
1126
1127    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1128    if (!res.pipeline) {
1129        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1130    }
1131
1132    return res;
1133}
1134
1135ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1136    assert(op->op == GGML_OP_TOP_K);
1137
1138    char base[256];
1139    char name[256];
1140
1141    ggml_sort_order order = GGML_SORT_ORDER_DESC;
1142
1143    const char * order_str = "undefined";
1144    switch (order) {
1145        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
1146        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1147        default: GGML_ABORT("fatal error");
1148    };
1149
1150    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1151    snprintf(name, 256, "%s", base);
1152
1153    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1154    if (!res.pipeline) {
1155        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1156    }
1157
1158    return res;
1159}
1160
1161ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1162        ggml_metal_library_t lib,
1163        const struct ggml_tensor * op,
1164        bool    has_mask,
1165        int32_t ncpsg) {
1166    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1167    GGML_UNUSED(op);
1168
1169    char base[256];
1170    char name[256];
1171
1172    snprintf(base, 256, "kernel_%s",
1173            "flash_attn_ext_pad");
1174
1175    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1176            base,
1177            has_mask,
1178            ncpsg);
1179
1180    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1181    if (!res.pipeline) {
1182        ggml_metal_cv_t cv = ggml_metal_cv_init();
1183
1184        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0);
1185        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1186        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
1187        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
1188
1189        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1190        //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1191        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
1192        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
1193        //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1194        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1195
1196        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1197
1198        ggml_metal_cv_free(cv);
1199    }
1200
1201    return res;
1202}
1203
1204ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1205        ggml_metal_library_t lib,
1206        const struct ggml_tensor * op,
1207        int32_t nqptg,
1208        int32_t ncpsg) {
1209    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1210    GGML_UNUSED(op);
1211
1212    char base[256];
1213    char name[256];
1214
1215    snprintf(base, 256, "kernel_%s",
1216            "flash_attn_ext_blk");
1217
1218    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1219            base,
1220            nqptg,
1221            ncpsg);
1222
1223    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1224    if (!res.pipeline) {
1225        ggml_metal_cv_t cv = ggml_metal_cv_init();
1226
1227        //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
1228        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1229        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
1230        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);
1231
1232        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1233        //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1234        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);
1235        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);
1236        ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1237        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1238
1239        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1240
1241        ggml_metal_cv_free(cv);
1242    }
1243
1244    return res;
1245}
1246
1247ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
1248        ggml_metal_library_t lib,
1249        const ggml_tensor * op,
1250        bool    has_mask,
1251        bool    has_sinks,
1252        bool    has_bias,
1253        bool    has_scap,
1254        bool    has_kvpad,
1255        int32_t nsg) {
1256    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1257
1258    char base[256];
1259    char name[256];
1260
1261    const int32_t dk = (int32_t) op->src[1]->ne[0];
1262    const int32_t dv = (int32_t) op->src[2]->ne[0];
1263
1264    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1265    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1266
1267    // do bounds checks for the mask?
1268    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1269
1270    snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1271            "flash_attn_ext",
1272            ggml_type_name(op->src[1]->type),
1273            dk,
1274            dv);
1275
1276    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
1277            base,
1278            has_mask,
1279            has_sinks,
1280            has_bias,
1281            has_scap,
1282            has_kvpad,
1283            bc_mask,
1284            ns10,
1285            ns20,
1286            nsg);
1287
1288    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1289    if (!res.pipeline) {
1290        ggml_metal_cv_t cv = ggml_metal_cv_init();
1291
1292        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT + 0);
1293        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1294        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
1295        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
1296        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1297
1298        ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1299
1300        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1301        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1302        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);
1303
1304        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1305
1306        ggml_metal_cv_free(cv);
1307    }
1308
1309    return res;
1310}
1311
1312ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1313        ggml_metal_library_t lib,
1314        const ggml_tensor * op,
1315        bool    has_mask,
1316        bool    has_sinks,
1317        bool    has_bias,
1318        bool    has_scap,
1319        bool    has_kvpad,
1320        int32_t nsg,
1321        int32_t nwg) {
1322    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1323
1324    char base[256];
1325    char name[256];
1326
1327    const int32_t dk = (int32_t) op->src[1]->ne[0];
1328    const int32_t dv = (int32_t) op->src[2]->ne[0];
1329
1330    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1331    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1332
1333    snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1334            "flash_attn_ext_vec",
1335            ggml_type_name(op->src[1]->type),
1336            dk,
1337            dv);
1338
1339    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1340            base,
1341            has_mask,
1342            has_sinks,
1343            has_bias,
1344            has_scap,
1345            has_kvpad,
1346            ns10,
1347            ns20,
1348            nsg, nwg);
1349
1350    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1351    if (!res.pipeline) {
1352        ggml_metal_cv_t cv = ggml_metal_cv_init();
1353
1354        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);
1355        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1356        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
1357        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
1358        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1359
1360        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1361        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1362        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);
1363        ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);
1364
1365        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1366
1367        ggml_metal_cv_free(cv);
1368    }
1369
1370    return res;
1371}
1372
1373ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1374        ggml_metal_library_t lib,
1375        const ggml_tensor * op,
1376        int32_t dv,
1377        int32_t nwg) {
1378    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1379
1380    char base[256];
1381    char name[256];
1382
1383    snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1384    snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1385
1386    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1387    if (!res.pipeline) {
1388        ggml_metal_cv_t cv = ggml_metal_cv_init();
1389
1390        ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1391        ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1392
1393        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1394
1395        ggml_metal_cv_free(cv);
1396    }
1397
1398    return res;
1399
1400    GGML_UNUSED(op);
1401}
1402
1403ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1404    char base[256];
1405    char name[256];
1406
1407    int op_num = -1;
1408
1409    switch (op->op) {
1410        case GGML_OP_ADD: op_num = 0; break;
1411        case GGML_OP_SUB: op_num = 1; break;
1412        case GGML_OP_MUL: op_num = 2; break;
1413        case GGML_OP_DIV: op_num = 3; break;
1414        default: GGML_ABORT("fatal error");
1415    };
1416
1417    const char * t0_str = ggml_type_name(op->src[0]->type);
1418    const char * t1_str = ggml_type_name(op->src[1]->type);
1419    const char * t_str  = ggml_type_name(op->type);
1420
1421    const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
1422
1423    const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
1424
1425    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
1426    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
1427
1428    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1429    if (!res.pipeline) {
1430        ggml_metal_cv_t cv = ggml_metal_cv_init();
1431
1432        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1433        ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
1434        ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);
1435
1436        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1437
1438        ggml_metal_cv_free(cv);
1439    }
1440
1441    res.c4  = is_c4;
1442    res.cnt = is_rb;
1443
1444    return res;
1445}
1446
1447ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
1448    char base[256];
1449    char name[256];
1450
1451    int op_num = -1;
1452
1453    switch (op) {
1454        case GGML_OP_ADD: op_num = 0; break;
1455        case GGML_OP_SUB: op_num = 1; break;
1456        case GGML_OP_MUL: op_num = 2; break;
1457        case GGML_OP_DIV: op_num = 3; break;
1458        default: GGML_ABORT("fatal error");
1459    };
1460
1461    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
1462    snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
1463
1464    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1465    if (!res.pipeline) {
1466        ggml_metal_cv_t cv = ggml_metal_cv_init();
1467
1468        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1469        ggml_metal_cv_set_int16(cv, 1,      FC_BIN + 1);
1470        ggml_metal_cv_set_bool (cv, false,  FC_BIN + 2);
1471
1472        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1473
1474        ggml_metal_cv_free(cv);
1475    }
1476
1477    return res;
1478}
1479
1480ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1481    assert(op->op == GGML_OP_L2_NORM);
1482
1483    char base[256];
1484    char name[256];
1485
1486    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
1487
1488    const char * t0_str = ggml_type_name(op->src[0]->type);
1489    const char * t_str  = ggml_type_name(op->type);
1490
1491    snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
1492    snprintf(name, 256, "%s", base);
1493
1494    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1495    if (!res.pipeline) {
1496        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1497    }
1498
1499    res.c4   = is_c4;
1500    res.smem = 32*sizeof(float);
1501
1502    return res;
1503}
1504
1505ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1506    assert(op->op == GGML_OP_GROUP_NORM);
1507
1508    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1509
1510    char base[256];
1511    char name[256];
1512
1513    snprintf(base, 256, "kernel_group_norm_f32");
1514    snprintf(name, 256, "%s", base);
1515
1516    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1517    if (!res.pipeline) {
1518        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1519    }
1520
1521    res.smem = 32*sizeof(float);
1522
1523    return res;
1524}
1525
1526ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1527    assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
1528
1529    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1530
1531    char base[256];
1532    char name[256];
1533
1534    const char * suffix = "";
1535    if (op->ne[0] % 4 == 0) {
1536        suffix = "_4";
1537    }
1538
1539    switch (op->op) {
1540        case GGML_OP_NORM:
1541            switch (n_fuse) {
1542                case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix);         break;
1543                case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix);     break;
1544                case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
1545                default: GGML_ABORT("fatal error");
1546            } break;
1547        case GGML_OP_RMS_NORM:
1548            switch (n_fuse) {
1549                case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix);         break;
1550                case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix);     break;
1551                case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
1552                default: GGML_ABORT("fatal error");
1553            } break;
1554        default: GGML_ABORT("fatal error");
1555    }
1556
1557    snprintf(name, 256, "%s", base);
1558
1559    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1560    if (!res.pipeline) {
1561        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1562    }
1563
1564    res.smem = 32*sizeof(float);
1565
1566    return res;
1567}
1568
1569ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1570    assert(op->op == GGML_OP_ROPE);
1571
1572    char base[256];
1573    char name[256];
1574
1575    const int mode = ((const int32_t *) op->op_params)[2];
1576
1577    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
1578    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
1579    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
1580    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
1581
1582    if (is_neox) {
1583        snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1584    } else if ((is_mrope || is_imrope) && !is_vision) {
1585        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1586        snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
1587    } else if (is_vision) {
1588        GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1589        snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type));
1590    } else {
1591        snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
1592    }
1593
1594    snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1595
1596    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1597    if (!res.pipeline) {
1598        ggml_metal_cv_t cv = ggml_metal_cv_init();
1599
1600        ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1601
1602        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1603
1604        ggml_metal_cv_free(cv);
1605    }
1606
1607    return res;
1608}
1609
1610ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1611    assert(op->op == GGML_OP_IM2COL);
1612
1613    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1614    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1615    GGML_ASSERT(op->type         == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1616
1617    char base[256];
1618    char name[256];
1619
1620    snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1621    snprintf(name, 256, "%s", base);
1622
1623    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1624    if (!res.pipeline) {
1625        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1626    }
1627
1628    return res;
1629}
1630
1631ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1632    assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
1633
1634    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1635    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1636    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1637    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1638    GGML_ASSERT(op->type         == GGML_TYPE_F32);
1639
1640    char base[256];
1641    char name[256];
1642
1643    snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1644    snprintf(name, 256, "%s", base);
1645
1646    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1647    if (!res.pipeline) {
1648        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1649    }
1650
1651    return res;
1652}
1653
1654ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1655    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
1656
1657    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1658    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1659    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1660    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1661    GGML_ASSERT(op->type         == GGML_TYPE_F32);
1662
1663    char base[256];
1664    char name[256];
1665
1666    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1667    snprintf(name, 256, "%s", base);
1668
1669    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1670    if (!res.pipeline) {
1671        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1672    }
1673
1674    return res;
1675}
1676
1677ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1678    assert(op->op == GGML_OP_CONV_2D);
1679
1680    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1681    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1682    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1683    GGML_ASSERT(op->type         == GGML_TYPE_F32);
1684
1685    char base[256];
1686    char name[256];
1687
1688    snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1689    snprintf(name, 256, "%s", base);
1690
1691    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1692    if (!res.pipeline) {
1693        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1694    }
1695
1696    return res;
1697}
1698
1699ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1700    assert(op->op == GGML_OP_UPSCALE);
1701
1702    char base[256];
1703    char name[256];
1704
1705    snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
1706    snprintf(name, 256, "%s", base);
1707
1708    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1709    if (!res.pipeline) {
1710        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1711    }
1712
1713    return res;
1714}
1715
1716ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1717    assert(op->op == GGML_OP_PAD);
1718
1719    char base[256];
1720    char name[256];
1721
1722    snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1723    snprintf(name, 256, "%s", base);
1724
1725    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1726    if (res.pipeline) {
1727        return res;
1728    }
1729
1730    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1731
1732    return res;
1733}
1734
1735ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1736    assert(op->op == GGML_OP_PAD_REFLECT_1D);
1737
1738    char base[256];
1739    char name[256];
1740
1741    snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
1742    snprintf(name, 256, "%s", base);
1743
1744    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1745    if (!res.pipeline) {
1746        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1747    }
1748
1749    return res;
1750}
1751
1752ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1753    assert(op->op == GGML_OP_ARANGE);
1754
1755    char base[256];
1756    char name[256];
1757
1758    snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
1759    snprintf(name, 256, "%s", base);
1760
1761    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1762    if (!res.pipeline) {
1763        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1764    }
1765
1766    return res;
1767}
1768
1769ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1770    assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
1771
1772    char base[256];
1773    char name[256];
1774
1775    snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
1776    snprintf(name, 256, "%s", base);
1777
1778    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1779    if (!res.pipeline) {
1780        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1781    }
1782
1783    return res;
1784}
1785
1786ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1787    assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1788
1789    char base[256];
1790    char name[256];
1791
1792    snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1793    snprintf(name, 256, "%s", base);
1794
1795    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1796    if (!res.pipeline) {
1797        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1798    }
1799
1800    return res;
1801}
1802
1803ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1804    assert(op->op == GGML_OP_OPT_STEP_SGD);
1805
1806    char base[256];
1807    char name[256];
1808
1809    snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1810    snprintf(name, 256, "%s", base);
1811
1812    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1813    if (!res.pipeline) {
1814        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1815    }
1816
1817    return res;
1818}
1819
1820ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor *  op) {
1821    GGML_ASSERT(op->type == GGML_TYPE_I64);
1822
1823    char base[256];
1824    char name[256];
1825
1826    snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
1827    snprintf(name, 256, "%s", base);
1828
1829    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1830    if (!res.pipeline) {
1831        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1832    }
1833
1834    return res;
1835}
1836
1837ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor *  op) {
1838    assert(op->op == GGML_OP_COUNT_EQUAL);
1839
1840    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1841
1842    GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1843    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
1844    GGML_ASSERT(op->type == GGML_TYPE_I64);
1845
1846    // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
1847    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
1848
1849    char base[256];
1850    char name[256];
1851
1852    int nsg = 1;
1853    while (32*nsg < ne00 && nsg < 32) {
1854        nsg *= 2;
1855    }
1856
1857    snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
1858    snprintf(name, 256, "%s_nsg=%d", base, nsg);
1859
1860    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1861    if (!res.pipeline) {
1862        ggml_metal_cv_t cv = ggml_metal_cv_init();
1863
1864        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1865
1866        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1867
1868        ggml_metal_cv_free(cv);
1869    }
1870
1871    res.smem = 32 * sizeof(int32_t);
1872    res.nsg  = nsg;
1873
1874    return res;
1875}