1#include "ggml-metal-ops.h"
   2
   3#include "ggml.h"
   4#include "ggml-impl.h"
   5#include "ggml-backend-impl.h"
   6
   7#include "ggml-metal-impl.h"
   8#include "ggml-metal-common.h"
   9#include "ggml-metal-device.h"
  10
  11#include <cassert>
  12#include <algorithm>
  13#include <limits>
  14#include <cmath>
  15
  16static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
  17    if (!t) {
  18        return { nullptr, 0 };
  19    }
  20
  21    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
  22
  23    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;
  24
  25    return ggml_metal_buffer_get_id(ctx, t);
  26}
  27
  28struct ggml_metal_op {
  29    ggml_metal_op(
  30        ggml_metal_device_t dev,
  31        ggml_metal_cmd_buf_t cmd_buf,
  32        ggml_cgraph * gf,
  33        int  idx_start,
  34        int  idx_end,
  35        bool use_fusion,
  36        bool use_concurrency,
  37        bool use_capture,
  38        int  debug_graph,
  39        int  debug_fusion) {
  40        this->dev             = dev;
  41        this->lib             = ggml_metal_device_get_library(dev);
  42        this->enc             = ggml_metal_encoder_init(cmd_buf, use_concurrency);
  43        this->mem_ranges      = ggml_mem_ranges_init(debug_graph);
  44        this->idx_start       = idx_start;
  45        this->idx_end         = idx_end;
  46        this->use_fusion      = use_fusion;
  47        this->use_concurrency = use_concurrency;
  48        this->use_capture     = use_capture;
  49        this->debug_graph     = debug_graph;
  50        this->debug_fusion    = debug_fusion;
  51        this->gf              = gf;
  52
  53        idxs.reserve(gf->n_nodes);
  54
  55        // filter empty nodes
  56        // TODO: this can be removed when the allocator starts filtering them earlier
  57        //       https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
  58        for (int i = idx_start; i < idx_end; i++) {
  59            if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
  60                idxs.push_back(i);
  61            }
  62        }
  63    }
  64
  65    ~ggml_metal_op() {
  66        ggml_metal_encoder_end_encoding(this->enc);
  67        ggml_metal_encoder_free(this->enc);
  68        ggml_mem_ranges_free(this->mem_ranges);
  69    }
  70
  71    int n_nodes() const {
  72        return idxs.size();
  73    }
  74
  75    ggml_tensor * node(int i) const {
  76        assert(i >= 0 && i < (int) idxs.size());
  77        return ggml_graph_node(gf, idxs[i]);
  78    }
  79
  80    bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
  81        assert(use_fusion);
  82        assert(i0 >= 0 && i0 < n_nodes());
  83
  84        if (i0 + n_ops > n_nodes()) {
  85            return false;
  86        }
  87
  88        return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
  89    }
  90
  91    ggml_metal_device_t  dev;
  92    ggml_metal_library_t lib;
  93    ggml_metal_encoder_t enc;
  94    ggml_mem_ranges_t    mem_ranges;
  95
  96    bool use_fusion;
  97    bool use_concurrency;
  98    bool use_capture;
  99
 100    int debug_graph;
 101    int debug_fusion;
 102
 103private:
 104    ggml_cgraph * gf;
 105
 106    int idx_start;
 107    int idx_end;
 108
 109    // non-empty node indices
 110    std::vector<int> idxs;
 111};
 112
 113ggml_metal_op_t ggml_metal_op_init(
 114        ggml_metal_device_t dev,
 115        ggml_metal_cmd_buf_t cmd_buf,
 116        ggml_cgraph * gf,
 117        int idx_start,
 118        int idx_end,
 119        bool use_fusion,
 120        bool use_concurrency,
 121        bool use_capture,
 122        int debug_graph,
 123        int debug_fusion) {
 124    ggml_metal_op_t res = new ggml_metal_op(
 125        dev,
 126        cmd_buf,
 127        gf,
 128        idx_start,
 129        idx_end,
 130        use_fusion,
 131        use_concurrency,
 132        use_capture,
 133        debug_graph,
 134        debug_fusion);
 135
 136    return res;
 137}
 138
 139void ggml_metal_op_free(ggml_metal_op_t ctx) {
 140    delete ctx;
 141}
 142
 143int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
 144    return ctx->n_nodes();
 145}
 146
 147static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
 148    if (!ctx->mem_ranges) {
 149        return true;
 150    }
 151
 152    ggml_metal_encoder_memory_barrier(ctx->enc);
 153
 154    ggml_mem_ranges_reset(ctx->mem_ranges);
 155
 156    return true;
 157}
 158
 159static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {
 160    if (!ctx->mem_ranges) {
 161        return false;
 162    }
 163
 164    return ggml_mem_ranges_check(ctx->mem_ranges, node);
 165}
 166
 167static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {
 168    if (!ctx->mem_ranges) {
 169        return true;
 170    }
 171
 172    return ggml_mem_ranges_add(ctx->mem_ranges, node);
 173}
 174
// Encode the node at dense position `idx` of this context, plus any nodes the
// per-op encoder fuses into it. Returns how many nodes were consumed (>= 1).
//
// Steps:
//   1. skip nodes that need no GPU work: empty tensors, view-like ops
//      (NONE/RESHAPE/VIEW/TRANSPOSE/PERMUTE), and nodes without
//      GGML_TENSOR_FLAG_COMPUTE. Each skipped node counts as one consumed node.
//   2. abort if the device does not support the op
//   3. decide whether the node may overlap previously encoded work; if not,
//      insert a barrier and clear the tracked memory ranges
//   4. dispatch to the per-op encoder, which returns the number of fused nodes
//   5. record the memory ranges of every consumed node for later checks
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
    struct ggml_tensor * node = ctx->node(idx);

    //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));

    // the constructor already drops empty nodes; this is a defensive re-check
    if (ggml_is_empty(node)) {
        return 1;
    }

    switch (node->op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_PERMUTE:
            {
                // noop -> next node
                if (ctx->debug_graph > 0) {
                    GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
                }
            } return 1;
        default:
            {
            } break;
    }

    if (!ggml_metal_device_supports_op(ctx->dev, node)) {
        GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node));
        GGML_ABORT("unsupported op");
    }

    // nodes not flagged for computation are skipped (still consuming one slot)
    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
        return 1;
    }

    int n_fuse = 1;

    // check if the current node can run concurrently with other nodes before it
    // the condition is that:
    //  - the current node cannot write to any previous src or dst ranges
    //  - the current node cannot read from any previous dst ranges
    //
    // if the condition is not satisfied, we put a memory barrier and clear all ranges
    // otherwise, we add the new ranges to the encoding context and process the node concurrently
    //
    {
        const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);

        if (!is_concurrent) {
            ggml_metal_op_concurrency_reset(ctx);
        }

        if (ctx->debug_graph > 0) {
            GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
        }
        // verbose dump: type, shape (ne), strides (nb) and contiguity of src[0..3] and dst
        if (ctx->debug_graph > 1) {
            // NOTE(review): locals are built for src[1..3] even when those pointers are
            // null; this relies on GGML_TENSOR_LOCALS tolerating null — confirm in ggml.h
            GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
            GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
            GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
            GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
            GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
            GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
            GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
            GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
            GGML_TENSOR_LOCALS( int64_t, ne,  node,         ne);
            GGML_TENSOR_LOCALS(uint64_t, nb,  node,         nb);

            if (node->src[0]) {
                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
                        ggml_is_contiguous(node->src[0]), node->src[0]->name);
            }
            if (node->src[1]) {
                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                        ggml_is_contiguous(node->src[1]), node->src[1]->name);
            }
            if (node->src[2]) {
                GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
                        ggml_is_contiguous(node->src[2]), node->src[2]->name);
            }
            if (node->src[3]) {
                GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
                        ggml_is_contiguous(node->src[3]), node->src[3]->name);
            }
            // NOTE(review): node is always non-null here (it was dereferenced above)
            if (node) {
                GGML_LOG_DEBUG("%s: node  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                        node->name);
            }
        }
    }

    // dispatch to the per-op encoder; each returns how many consecutive nodes
    // (starting at idx) it encoded, i.e. 1 unless it fused following nodes
    switch (node->op) {
        case GGML_OP_CONCAT:
            {
                n_fuse = ggml_metal_op_concat(ctx, idx);
            } break;
        case GGML_OP_ADD:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
            {
                n_fuse = ggml_metal_op_bin(ctx, idx);
            } break;
        case GGML_OP_ADD_ID:
            {
                n_fuse = ggml_metal_op_add_id(ctx, idx);
            } break;
        case GGML_OP_REPEAT:
            {
                n_fuse = ggml_metal_op_repeat(ctx, idx);
            } break;
        case GGML_OP_ACC:
            {
                n_fuse = ggml_metal_op_acc(ctx, idx);
            } break;
        case GGML_OP_SCALE:
        case GGML_OP_FILL:
        case GGML_OP_CLAMP:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_LOG:
        case GGML_OP_UNARY:
            {
                n_fuse = ggml_metal_op_unary(ctx, idx);
            } break;
        case GGML_OP_GLU:
            {
                n_fuse = ggml_metal_op_glu(ctx, idx);
            } break;
        case GGML_OP_SUM:
            {
                n_fuse = ggml_metal_op_sum(ctx, idx);
            } break;
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
            {
                n_fuse = ggml_metal_op_sum_rows(ctx, idx);
            } break;
        case GGML_OP_CUMSUM:
            {
                n_fuse = ggml_metal_op_cumsum(ctx, idx);
            } break;
        case GGML_OP_SOFT_MAX:
            {
                n_fuse = ggml_metal_op_soft_max(ctx, idx);
            } break;
        case GGML_OP_SSM_CONV:
            {
                n_fuse = ggml_metal_op_ssm_conv(ctx, idx);
            } break;
        case GGML_OP_SSM_SCAN:
            {
                n_fuse = ggml_metal_op_ssm_scan(ctx, idx);
            } break;
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_RWKV_WKV7:
            {
                n_fuse = ggml_metal_op_rwkv(ctx, idx);
            } break;
        case GGML_OP_SOLVE_TRI:
            {
                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
            } break;
        case GGML_OP_MUL_MAT:
            {
                n_fuse = ggml_metal_op_mul_mat(ctx, idx);
            } break;
        case GGML_OP_MUL_MAT_ID:
            {
                n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);
            } break;
        case GGML_OP_GET_ROWS:
            {
                n_fuse = ggml_metal_op_get_rows(ctx, idx);
            } break;
        case GGML_OP_SET_ROWS:
            {
                n_fuse = ggml_metal_op_set_rows(ctx, idx);
            } break;
        case GGML_OP_DIAG:
            {
                n_fuse = ggml_metal_op_diag(ctx, idx);
            } break;
        case GGML_OP_L2_NORM:
            {
                n_fuse = ggml_metal_op_l2_norm(ctx, idx);
            } break;
        case GGML_OP_GROUP_NORM:
            {
                n_fuse = ggml_metal_op_group_norm(ctx, idx);
            } break;
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
            {
                n_fuse = ggml_metal_op_norm(ctx, idx);
            } break;
        case GGML_OP_ROPE:
            {
                n_fuse = ggml_metal_op_rope(ctx, idx);
            } break;
        case GGML_OP_IM2COL:
            {
                n_fuse = ggml_metal_op_im2col(ctx, idx);
            } break;
        case GGML_OP_CONV_2D:
            {
                n_fuse = ggml_metal_op_conv_2d(ctx, idx);
            } break;
        case GGML_OP_CONV_TRANSPOSE_1D:
            {
                n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
            } break;
        case GGML_OP_CONV_TRANSPOSE_2D:
            {
                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
            } break;
        case GGML_OP_UPSCALE:
            {
                n_fuse = ggml_metal_op_upscale(ctx, idx);
            } break;
        case GGML_OP_PAD:
            {
                n_fuse = ggml_metal_op_pad(ctx, idx);
            } break;
        case GGML_OP_PAD_REFLECT_1D:
            {
                n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
            } break;
        case GGML_OP_ARANGE:
            {
                n_fuse = ggml_metal_op_arange(ctx, idx);
            } break;
        case GGML_OP_TIMESTEP_EMBEDDING:
            {
                n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);
            } break;
        case GGML_OP_ARGSORT:
            {
                n_fuse = ggml_metal_op_argsort(ctx, idx);
            } break;
        case GGML_OP_TOP_K:
            {
                n_fuse = ggml_metal_op_top_k(ctx, idx);
            } break;
        case GGML_OP_TRI:
            {
                n_fuse = ggml_metal_op_tri(ctx, idx);
            } break;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
            } break;
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
            {
                n_fuse = ggml_metal_op_cpy(ctx, idx);
            } break;
        case GGML_OP_POOL_1D:
            {
                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
            } break;
        case GGML_OP_POOL_2D:
            {
                n_fuse = ggml_metal_op_pool_2d(ctx, idx);
            } break;
        case GGML_OP_ARGMAX:
            {
                n_fuse = ggml_metal_op_argmax(ctx, idx);
            } break;
        case GGML_OP_OPT_STEP_ADAMW:
            {
                n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
            } break;
        case GGML_OP_OPT_STEP_SGD:
            {
                n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
            } break;
        case GGML_OP_COUNT_EQUAL:
            {
                n_fuse = ggml_metal_op_count_equal(ctx, idx);
            } break;
        default:
            {
                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
                GGML_ABORT("fatal error");
            }
    }

    if (ctx->debug_graph > 0) {
        if (n_fuse > 1) {
            GGML_LOG_DEBUG("%s:               fuse %d ops\n", __func__, n_fuse);
        }
    }

    // update the mem ranges in the encoding context
    // (if the tracker rejects a node's ranges, emit a barrier and start a new group)
    for (int i = 0; i < n_fuse; ++i) {
        if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
            ggml_metal_op_concurrency_reset(ctx);
        }
    }

    return n_fuse;
}
 481
 482int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
 483    if (ctx->use_capture) {
 484        ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
 485    }
 486
 487    int res = ggml_metal_op_encode_impl(ctx, idx);
 488    if (idx + res > ctx->n_nodes()) {
 489        GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
 490                "https://github.com/ggml-org/llama.cpp/pull/14849");
 491    }
 492
 493    if (ctx->use_capture) {
 494        ggml_metal_encoder_debug_group_pop(ctx->enc);
 495    }
 496
 497    return res;
 498}
 499
 500int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
 501    ggml_tensor * op = ctx->node(idx);
 502
 503    ggml_metal_library_t lib = ctx->lib;
 504    ggml_metal_encoder_t enc = ctx->enc;
 505
 506    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 507    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
 508    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 509    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
 510    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
 511    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 512
 513    const int32_t dim = ((const int32_t *) op->op_params)[0];
 514
 515    ggml_metal_kargs_concat args = {
 516        /*.ne00 =*/ ne00,
 517        /*.ne01 =*/ ne01,
 518        /*.ne02 =*/ ne02,
 519        /*.ne03 =*/ ne03,
 520        /*.nb00 =*/ nb00,
 521        /*.nb01 =*/ nb01,
 522        /*.nb02 =*/ nb02,
 523        /*.nb03 =*/ nb03,
 524        /*.ne10 =*/ ne10,
 525        /*.ne11 =*/ ne11,
 526        /*.ne12 =*/ ne12,
 527        /*.ne13 =*/ ne13,
 528        /*.nb10 =*/ nb10,
 529        /*.nb11 =*/ nb11,
 530        /*.nb12 =*/ nb12,
 531        /*.nb13 =*/ nb13,
 532        /*.ne0  =*/ ne0,
 533        /*.ne1  =*/ ne1,
 534        /*.ne2  =*/ ne2,
 535        /*.ne3  =*/ ne3,
 536        /*.nb0  =*/ nb0,
 537        /*.nb1  =*/ nb1,
 538        /*.nb2  =*/ nb2,
 539        /*.nb3  =*/ nb3,
 540        /*.dim  =*/ dim,
 541    };
 542
 543    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
 544
 545    ggml_metal_encoder_set_pipeline(enc, pipeline);
 546    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 547    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
 548    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
 549    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 550
 551    const int nth = std::min(1024, ne0);
 552
 553    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
 554
 555    return 1;
 556}
 557
 558int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
 559    ggml_tensor * op = ctx->node(idx);
 560
 561    ggml_metal_library_t lib = ctx->lib;
 562    ggml_metal_encoder_t enc = ctx->enc;
 563
 564    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 565    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
 566    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
 567    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 568
 569    auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
 570
 571    ggml_metal_kargs_repeat args = {
 572        /*.ne00 =*/ ne00,
 573        /*.ne01 =*/ ne01,
 574        /*.ne02 =*/ ne02,
 575        /*.ne03 =*/ ne03,
 576        /*.nb00 =*/ nb00,
 577        /*.nb01 =*/ nb01,
 578        /*.nb02 =*/ nb02,
 579        /*.nb03 =*/ nb03,
 580        /*.ne0  =*/ ne0,
 581        /*.ne1  =*/ ne1,
 582        /*.ne2  =*/ ne2,
 583        /*.ne3  =*/ ne3,
 584        /*.nb0  =*/ nb0,
 585        /*.nb1  =*/ nb1,
 586        /*.nb2  =*/ nb2,
 587        /*.nb3  =*/ nb3,
 588    };
 589
 590    ggml_metal_encoder_set_pipeline(enc, pipeline);
 591    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 592    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
 593    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 594
 595    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
 596
 597    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
 598
 599    return 1;
 600}
 601
// Encode GGML_OP_ACC (F32 only, contiguous inputs): add src[1] into a strided
// region of the destination, where the destination starts as a copy of src[0].
//
// op_params layout, as read below: [0..2] = pnb1, pnb2, pnb3 (byte strides used
// for the src0/dst side of the add), [3] = offs (byte offset of the region),
// [4] = inplace flag.
// NOTE(review): these presumably mirror the (nb1, nb2, nb3, offset, inplace)
// arguments of ggml_acc — confirm against ggml.c.
//
// When not in place, a copy kernel first writes src[0] into dst. A
// concurrency reset (barrier) follows before the add runs.
int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type         == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));

    // op_params are stored as int32_t and widened to size_t here
    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
    const size_t offs = ((const int32_t *) op->op_params)[3];

    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];

    if (!inplace) {
        // run a separate kernel to cpy src->dst
        // not sure how to avoid this
        // TODO: make a simpler cpy_bytes kernel

        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);

        ggml_metal_kargs_cpy args = {
            /*.nk0  =*/ ne00,
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.ne03 =*/ ne03,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
            /*.ne2  =*/ ne2,
            /*.ne3  =*/ ne3,
            /*.nb0  =*/ nb0,
            /*.nb1  =*/ nb1,
            /*.nb2  =*/ nb2,
            /*.nb3  =*/ nb3,
        };

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);

        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);

        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);

        // the add below reads/writes dst, so the copy must complete first
        ggml_metal_op_concurrency_reset(ctx);
    }

    // element-wise ADD using the view strides (pnb*) and byte offset (offs)
    // for the src0/dst side, and src1's own strides for the other operand
    ggml_metal_kargs_bin args = {
        /*.ne00 =*/ ne00,
        /*.ne01 =*/ ne01,
        /*.ne02 =*/ ne02,
        /*.ne03 =*/ ne03,
        /*.nb00 =*/ nb00,
        /*.nb01 =*/ pnb1,
        /*.nb02 =*/ pnb2,
        /*.nb03 =*/ pnb3,
        /*.ne10 =*/ ne10,
        /*.ne11 =*/ ne11,
        /*.ne12 =*/ ne12,
        /*.ne13 =*/ ne13,
        /*.nb10 =*/ nb10,
        /*.nb11 =*/ nb11,
        /*.nb12 =*/ nb12,
        /*.nb13 =*/ nb13,
        /*.ne0  =*/ ne0,
        /*.ne1  =*/ ne1,
        /*.ne2  =*/ ne2,
        /*.ne3  =*/ ne3,
        /*.nb0  =*/ nb0,
        /*.nb1  =*/ pnb1,
        /*.nb2  =*/ pnb2,
        /*.nb3  =*/ pnb3,
        /*.offs =*/ offs,
        /*.o1   =*/ { 0 },
    };

    auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);

    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);

    // one threadgroup per src1 row: (ne11, ne12, ne13) grid
    ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);

    return 1;
}
 712
 713int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
 714    ggml_tensor * op = ctx->node(idx);
 715
 716    ggml_metal_library_t lib = ctx->lib;
 717    ggml_metal_encoder_t enc = ctx->enc;
 718
 719    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 720    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
 721    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
 722    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 723
 724    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 725
 726    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
 727    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
 728
 729    ggml_metal_kargs_unary args = {
 730        /*.ne00  =*/ ne00,
 731        /*.ne01  =*/ ne01,
 732        /*.ne02  =*/ ne02,
 733        /*.ne03  =*/ ne03,
 734        /*.nb00  =*/ nb00,
 735        /*.nb01  =*/ nb01,
 736        /*.nb02  =*/ nb02,
 737        /*.nb03  =*/ nb03,
 738        /*.ne0   =*/ ne0,
 739        /*.ne1   =*/ ne1,
 740        /*.ne2   =*/ ne2,
 741        /*.ne3   =*/ ne3,
 742        /*.nb0   =*/ nb0,
 743        /*.nb1   =*/ nb1,
 744        /*.nb2   =*/ nb2,
 745        /*.nb3   =*/ nb3,
 746        /*.slope =*/ 0.0,
 747        /*.scale =*/ 0.0,
 748        /*.bias  =*/ 0.0,
 749        /*.val   =*/ 0.0,
 750        /*.min   =*/ 0.0,
 751        /*.max   =*/ 0.0,
 752    };
 753
 754    if (op->op == GGML_OP_LEAKY_RELU) {
 755        args.slope = ggml_get_op_params_f32(op, 0);
 756    }
 757
 758    if (op->op == GGML_OP_SCALE) {
 759        args.scale = ggml_get_op_params_f32(op, 0);
 760        args.bias  = ggml_get_op_params_f32(op, 1);
 761    }
 762
 763    if (op->op == GGML_OP_FILL) {
 764        args.val = ggml_get_op_params_f32(op, 0);
 765    }
 766
 767    if (op->op == GGML_OP_CLAMP) {
 768        args.min = ggml_get_op_params_f32(op, 0);
 769        args.max = ggml_get_op_params_f32(op, 1);
 770    }
 771
 772    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 773
 774    if (pipeline.c4) {
 775        args.ne00 = ne00/4;
 776        args.ne0  = ne0/4;
 777    }
 778
 779    ggml_metal_encoder_set_pipeline(enc, pipeline);
 780    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 781    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
 782    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 783
 784    if (pipeline.cnt) {
 785        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
 786
 787        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
 788    } else {
 789        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 790
 791        const int nth = MIN(args.ne00, nth_max);
 792
 793        const int nk0 = (args.ne00 + nth - 1)/nth;
 794
 795        ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
 796    }
 797
 798    return 1;
 799}
 800
 801int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
 802    ggml_tensor * op = ctx->node(idx);
 803
 804    ggml_metal_library_t lib = ctx->lib;
 805    ggml_metal_encoder_t enc = ctx->enc;
 806
 807    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 808    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
 809    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 810    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
 811    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
 812    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 813
 814    if (op->src[1]) {
 815        GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
 816    }
 817
 818    auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
 819
 820    const int32_t swp = ggml_get_op_params_i32(op, 1);
 821    const float alpha = ggml_get_op_params_f32(op, 2);
 822    const float limit = ggml_get_op_params_f32(op, 3);
 823
 824    const int32_t i00 = swp ? ne0 : 0;
 825    const int32_t i10 = swp ? 0 : ne0;
 826
 827    ggml_metal_kargs_glu args = {
 828        /*.ne00 =*/ ne00,
 829        /*.nb01 =*/ nb01,
 830        /*.ne10 =*/ op->src[1] ? ne10 : ne00,
 831        /*.nb11 =*/ op->src[1] ? nb11 : nb01,
 832        /*.ne0  =*/ ne0,
 833        /*.nb1  =*/ nb1,
 834        /*.i00  =*/ op->src[1] ? 0 : i00,
 835        /*.i10  =*/ op->src[1] ? 0 : i10,
 836        /*.alpha=*/ alpha,
 837        /*.limit=*/ limit
 838    };
 839
 840    const int64_t nrows = ggml_nrows(op->src[0]);
 841
 842    const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
 843
 844    ggml_metal_encoder_set_pipeline(enc, pipeline);
 845    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 846    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
 847    if (op->src[1]) {
 848        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
 849    } else {
 850        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 2);
 851    }
 852    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 853
 854    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
 855
 856    return 1;
 857}
 858
 859int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
 860    ggml_tensor * op  = ctx->node(idx);
 861
 862    ggml_metal_library_t lib = ctx->lib;
 863    ggml_metal_encoder_t enc = ctx->enc;
 864
 865    const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
 866
 867    ggml_metal_kargs_sum args = {
 868        /*.np =*/ n,
 869    };
 870
 871    auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
 872
 873    int nth = 32; // SIMD width
 874
 875    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
 876        nth *= 2;
 877    }
 878
 879    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 880    nth = std::min(nth, (int) n);
 881
 882    const int nsg = (nth + 31) / 32;
 883
 884    ggml_metal_encoder_set_pipeline(enc, pipeline);
 885    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 886    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
 887    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 888
 889    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
 890
 891    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
 892
 893    return 1;
 894}
 895
 896int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 897    ggml_tensor * op = ctx->node(idx);
 898
 899    ggml_metal_library_t lib = ctx->lib;
 900    ggml_metal_encoder_t enc = ctx->enc;
 901
 902    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 903    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
 904    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
 905    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 906
 907    ggml_metal_kargs_sum_rows args = {
 908        /*.ne00 =*/ ne00,
 909        /*.ne01 =*/ ne01,
 910        /*.ne02 =*/ ne02,
 911        /*.ne03 =*/ ne03,
 912        /*.nb00 =*/ nb00,
 913        /*.nb01 =*/ nb01,
 914        /*.nb02 =*/ nb02,
 915        /*.nb03 =*/ nb03,
 916        /*.ne0  =*/ ne0,
 917        /*.ne1  =*/ ne1,
 918        /*.ne2  =*/ ne2,
 919        /*.ne3  =*/ ne3,
 920        /*.nb0  =*/ nb0,
 921        /*.nb1  =*/ nb1,
 922        /*.nb2  =*/ nb2,
 923        /*.nb3  =*/ nb3,
 924    };
 925
 926    auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
 927
 928    int nth = 32; // SIMD width
 929
 930    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
 931        nth *= 2;
 932    }
 933
 934    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 935    nth = std::min(nth, ne00);
 936
 937    const size_t smem = pipeline.smem;
 938
 939    ggml_metal_encoder_set_pipeline(enc, pipeline);
 940    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
 941    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
 942    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 943
 944    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 945
 946    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 947
 948    return 1;
 949}
 950
// Encode a cumulative (prefix) sum along dim 0 of src[0] into op.
//
// Each row of ne00 elements is split into net0 = ceil(ne00/nth) chunks of nth elements:
//   pass 1 (cumsum_blk): scan every chunk of src[0] into dst. When a row spans more than one
//          chunk (outb = ne00 > nth), the kernel also writes into a float scratch tensor `tmp`
//          laid out as [net0, ne01, ne02, ne03] (presumably the per-chunk totals; TODO confirm).
//   pass 2 (cumsum_blk, only if ne00 > nth): scan each tmp row in place, one threadgroup per row.
//   pass 3 (cumsum_add, only if ne00 > nth): combine the scanned tmp values into dst.
// The two-level scheme requires net0 <= nth, i.e. ne00 <= nth*nth (asserted below).
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);

    // chunk size: smallest power of two >= ne00 that fits the pipeline's thread limit
    int nth = 1;
    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
        nth *= 2;
    }

    // pass 2 scans a whole tmp row (net0 values) in a single threadgroup of nth threads
    GGML_ASSERT(ne00 <= nth*nth);

    // shape of the scratch tensor: net0 chunks per row, same higher dims as src[0]
    const int64_t net0 = (ne00 + nth - 1) / nth;
    const int64_t net1 = ne01;
    const int64_t net2 = ne02;
    const int64_t net3 = ne03;

    // contiguous float strides of the scratch tensor
    const uint64_t nbt0 = sizeof(float);
    const uint64_t nbt1 = net0*nbt0;
    const uint64_t nbt2 = net1*nbt1;
    const uint64_t nbt3 = net2*nbt2;

    // 32 floats of threadgroup memory, padded to 16 bytes (presumably one slot per SIMD group)
    const size_t smem = GGML_PAD(32*sizeof(float), 16);

    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);

    // The scratch tensor lives directly after dst's data in the same buffer.
    // NOTE(review): this assumes the allocation reserved that extra space (e.g. via the backend's
    // allocation-size hook); confirm.
    ggml_metal_buffer_id bid_tmp = bid_dst;
    bid_tmp.offs += ggml_nbytes(op);

    // pass 1: chunk-wise scan of src[0] into dst (plus tmp when a row spans several chunks)
    {
        ggml_metal_kargs_cumsum_blk args = {
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.ne03 =*/ ne03,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.net0 =*/ net0,
            /*.net1 =*/ net1,
            /*.net2 =*/ net2,
            /*.net3 =*/ net3,
            /*.nbt0 =*/ nbt0,
            /*.nbt1 =*/ nbt1,
            /*.nbt2 =*/ nbt2,
            /*.nbt3 =*/ nbt3,
            /*.outb =*/ ne00 > nth,
        };

        ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  2);
        ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

        // one threadgroup per (chunk, row) pair
        ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
    }

    if (ne00 > nth) {
        // pass 2 reads what pass 1 wrote to tmp. The reset presumably forces ordering between the
        // passes (see ggml_metal_op_concurrency_reset).
        ggml_metal_op_concurrency_reset(ctx);

        // pass 2: scan each tmp row in place (src, tmp-out and dst all alias the scratch tensor)
        {
            ggml_metal_kargs_cumsum_blk args = {
                /*.ne00 =*/ net0,
                /*.ne01 =*/ net1,
                /*.ne02 =*/ net2,
                /*.ne03 =*/ net3,
                /*.nb00 =*/ nbt0,
                /*.nb01 =*/ nbt1,
                /*.nb02 =*/ nbt2,
                /*.nb03 =*/ nbt3,
                /*.net0 =*/ net0,
                /*.net1 =*/ net1,
                /*.net2 =*/ net2,
                /*.net3 =*/ net3,
                /*.nbt0 =*/ nbt0,
                /*.nbt1 =*/ nbt1,
                /*.nbt2 =*/ nbt2,
                /*.nbt3 =*/ nbt3,
                /*.outb =*/ false,
            };

            ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 2);
            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

            // one threadgroup per tmp row (net0 <= nth values each)
            ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
        }

        // pass 3 depends on pass 2's in-place results
        ggml_metal_op_concurrency_reset(ctx);

        // pass 3: combine the scanned tmp values into every chunk of dst
        {
            auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);

            ggml_metal_kargs_cumsum_add args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.net0 =*/ net0,
                /*.net1 =*/ net1,
                /*.net2 =*/ net2,
                /*.net3 =*/ net3,
                /*.nbt0 =*/ nbt0,
                /*.nbt1 =*/ nbt1,
                /*.nbt2 =*/ nbt2,
                /*.nbt3 =*/ nbt3,
            };

            ggml_metal_encoder_set_pipeline(enc, pipeline_add);
            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);

            ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
        }
    }

    return 1;
}
1093
1094int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
1095    ggml_tensor * op = ctx->node(idx);
1096
1097    ggml_metal_library_t lib = ctx->lib;
1098    ggml_metal_encoder_t enc = ctx->enc;
1099
1100    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1101    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1102    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1103    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1104    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1105    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1106
1107    auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1108
1109    ggml_metal_kargs_get_rows args = {
1110        /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1111        /*.ne00  =*/ ne00,
1112        /*.nb01  =*/ nb01,
1113        /*.nb02  =*/ nb02,
1114        /*.nb03  =*/ nb03,
1115        /*.ne10  =*/ ne10,
1116        /*.nb10  =*/ nb10,
1117        /*.nb11  =*/ nb11,
1118        /*.nb12  =*/ nb12,
1119        /*.nb1   =*/ nb1,
1120        /*.nb2   =*/ nb2,
1121        /*.nb3   =*/ nb3,
1122    };
1123
1124    const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1125
1126    const int nw0 = (args.ne00t + nth - 1)/nth;
1127
1128    ggml_metal_encoder_set_pipeline(enc, pipeline);
1129    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
1130    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1131    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1132    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
1133
1134    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
1135
1136    return 1;
1137}
1138
1139int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1140    ggml_tensor * op = ctx->node(idx);
1141
1142    ggml_metal_library_t lib = ctx->lib;
1143    ggml_metal_encoder_t enc = ctx->enc;
1144
1145    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1146    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1147    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1148    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1149    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1150    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1151
1152    auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1153
1154    const int32_t nk0 = ne0/ggml_blck_size(op->type);
1155
1156    int nth = 32; // SIMD width
1157
1158    while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1159        nth *= 2;
1160    }
1161
1162    int nrptg = 1;
1163    if (nth > nk0) {
1164        nrptg = (nth + nk0 - 1)/nk0;
1165        nth   = nk0;
1166
1167        if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1168            nrptg--;
1169        }
1170    }
1171
1172    nth = std::min(nth, nk0);
1173
1174    ggml_metal_kargs_set_rows args = {
1175        /*.nk0  =*/ nk0,
1176        /*.ne01 =*/ ne01,
1177        /*.nb01 =*/ nb01,
1178        /*.nb02 =*/ nb02,
1179        /*.nb03 =*/ nb03,
1180        /*.ne11 =*/ ne11,
1181        /*.ne12 =*/ ne12,
1182        /*.nb10 =*/ nb10,
1183        /*.nb11 =*/ nb11,
1184        /*.nb12 =*/ nb12,
1185        /*.nb1  =*/ nb1,
1186        /*.nb2  =*/ nb2,
1187        /*.nb3  =*/ nb3,
1188    };
1189
1190    ggml_metal_encoder_set_pipeline(enc, pipeline);
1191    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
1192    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1193    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1194    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
1195
1196    ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1197
1198    return 1;
1199}
1200
1201int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1202    ggml_tensor * op = ctx->node(idx);
1203
1204    ggml_metal_library_t lib = ctx->lib;
1205    ggml_metal_encoder_t enc = ctx->enc;
1206
1207    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
1208    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1209    GGML_TENSOR_LOCALS(int32_t,  ne, op, ne);
1210    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1211
1212    ggml_metal_kargs_diag args = {
1213        /*.ne00 =*/ne00,
1214        /*.ne01 =*/ne01,
1215        /*.ne02 =*/ne02,
1216        /*.ne03 =*/ne03,
1217        /*.nb00 =*/nb00,
1218        /*.nb01 =*/nb01,
1219        /*.nb02 =*/nb02,
1220        /*.nb03 =*/nb03,
1221        /*.ne0  =*/ne0,
1222        /*.ne1  =*/ne1,
1223        /*.ne2  =*/ne2,
1224        /*.ne3  =*/ne3,
1225        /*.nb0  =*/nb0,
1226        /*.nb1  =*/nb1,
1227        /*.nb2  =*/nb2,
1228        /*.nb3  =*/nb3,
1229    };
1230
1231    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1232
1233    ggml_metal_encoder_set_pipeline(enc, pipeline);
1234    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1235    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1236    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         2);
1237
1238    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1239
1240    return 1;
1241}
1242
1243int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1244    ggml_tensor * op = ctx->node(idx);
1245
1246    ggml_metal_library_t lib = ctx->lib;
1247    ggml_metal_encoder_t enc = ctx->enc;
1248
1249    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1250    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1251    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1252    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1253    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1254    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1255    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1256    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1257
1258    float scale;
1259    float max_bias;
1260
1261    memcpy(&scale,    ((const int32_t *) op->op_params) + 0, sizeof(scale));
1262    memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
1263
1264    const uint32_t n_head      = op->src[0]->ne[2];
1265    const  int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1266
1267    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
1268    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1269
1270    // softmax
1271
1272    ggml_metal_kargs_soft_max args = {
1273        /*.ne00        =*/ ne00,
1274        /*.ne01        =*/ ne01,
1275        /*.ne02        =*/ ne02,
1276        /*.nb01        =*/ nb01,
1277        /*.nb02        =*/ nb02,
1278        /*.nb03        =*/ nb03,
1279        /*.ne11        =*/ ne11,
1280        /*.ne12        =*/ ne12,
1281        /*.ne13        =*/ ne13,
1282        /*.nb11        =*/ nb11,
1283        /*.nb12        =*/ nb12,
1284        /*.nb13        =*/ nb13,
1285        /*.nb1         =*/ nb1,
1286        /*.nb2         =*/ nb2,
1287        /*.nb3         =*/ nb3,
1288        /*.scale       =*/ scale,
1289        /*.max_bias    =*/ max_bias,
1290        /*.m0          =*/ m0,
1291        /*.m1          =*/ m1,
1292        /*.n_head_log2 =*/ n_head_log2,
1293    };
1294
1295    auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1296
1297    int nth = 32; // SIMD width
1298
1299    if (ne00%4 == 0) {
1300        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1301            nth *= 2;
1302        }
1303    } else {
1304        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1305            nth *= 2;
1306        }
1307    }
1308
1309    const size_t smem = pipeline.smem;
1310
1311    ggml_metal_encoder_set_pipeline(enc, pipeline);
1312    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1313    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1314    if (op->src[1]) {
1315        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1316    } else {
1317        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
1318    }
1319    if (op->src[2]) {
1320        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
1321    } else {
1322        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
1323    }
1324    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
1325
1326    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1327
1328    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1329
1330    return 1;
1331}
1332
1333int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1334    ggml_tensor * op = ctx->node(idx);
1335
1336    ggml_metal_library_t lib = ctx->lib;
1337    ggml_metal_encoder_t enc = ctx->enc;
1338
1339    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1340    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1341    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1342    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1343    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1344    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1345
1346    ggml_metal_kargs_ssm_conv args = {
1347        /*.ne00 =*/ ne00,
1348        /*.ne01 =*/ ne01,
1349        /*.ne02 =*/ ne02,
1350        /*.nb00 =*/ nb00,
1351        /*.nb01 =*/ nb01,
1352        /*.nb02 =*/ nb02,
1353        /*.ne10 =*/ ne10,
1354        /*.ne11 =*/ ne11,
1355        /*.nb10 =*/ nb10,
1356        /*.nb11 =*/ nb11,
1357        /*.ne0  =*/ ne0,
1358        /*.ne1  =*/ ne1,
1359        /*.ne2  =*/ ne2,
1360        /*.nb0  =*/ nb0,
1361        /*.nb1  =*/ nb1,
1362        /*.nb2  =*/ nb2,
1363    };
1364
1365    // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1366    const bool use_batched = (ne1 > 1);
1367
1368    if (use_batched) {
1369        // Determine the smallest power of 2 that's >= ne1, but <= 256
1370        int BATCH_SIZE;
1371        if      (ne1 > 128) BATCH_SIZE = 256;
1372        else if (ne1 > 64 ) BATCH_SIZE = 128;
1373        else if (ne1 > 32 ) BATCH_SIZE = 64;
1374        else if (ne1 > 16 ) BATCH_SIZE = 32;
1375        else if (ne1 > 8  ) BATCH_SIZE = 16;
1376        else if (ne1 > 4  ) BATCH_SIZE = 8;
1377        else                BATCH_SIZE = 2;
1378
1379        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1380
1381        ggml_metal_encoder_set_pipeline(enc, pipeline);
1382        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1383        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1384        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1385        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
1386
1387        // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1388        // Each threadgroup has BATCH_SIZE threads, each handling one token
1389        const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1390        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1391    } else {
1392        auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1393
1394        ggml_metal_encoder_set_pipeline(enc, pipeline);
1395        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1396        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1397        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1398        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
1399
1400        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1401    }
1402
1403    return 1;
1404}
1405
// Encode the SSM selective scan over src[0..6] into op.
// Sizes are read from the sources: d_state = src0.ne0, d_inner = src0.ne1, n_head = src0.ne2,
// n_group = src4.ne1, n_seq_tokens = src1.ne2, n_seqs = src1.ne3.
// One threadgroup of d_state threads is dispatched per (d_inner, n_head, n_seqs) triple, so
// d_state must not exceed the pipeline's thread limit (asserted below).
// NOTE(review): the roles of the inputs (presumably s, x, dt, A, B, C, ids for src[0..6]) are not
// visible here; confirm against ggml_ssm_scan.
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
    GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
    GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
    GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    const ggml_tensor * src3 = op->src[3];
    const ggml_tensor * src4 = op->src[4];
    const ggml_tensor * src5 = op->src[5];
    const ggml_tensor * src6 = op->src[6];

    // all seven inputs are required; src[3..6] are bound unconditionally below
    GGML_ASSERT(src3);
    GGML_ASSERT(src4);
    GGML_ASSERT(src5);
    GGML_ASSERT(src6);

    const int64_t d_state      = ne00;
    const int64_t d_inner      = ne01;
    const int64_t n_head       = ne02;
    const int64_t n_group      = ne41;
    const int64_t n_seq_tokens = ne12;
    const int64_t n_seqs       = ne13;

    // s_off: byte size of src[1] as floats (presumably where the final state starts inside op's
    // buffer, after the per-token outputs; TODO confirm).
    // ns12/ns21/ns42/ns52: strides expressed in elements (byte stride / element byte stride).
    ggml_metal_kargs_ssm_scan args = {
        /*.d_state      =*/ d_state,
        /*.d_inner      =*/ d_inner,
        /*.n_head       =*/ n_head,
        /*.n_group      =*/ n_group,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs       =*/ n_seqs,
        /*.s_off        =*/ ggml_nelements(op->src[1]) * sizeof(float),
        /*.nb00         =*/ nb00,
        /*.nb01         =*/ nb01,
        /*.nb02         =*/ nb02,
        /*.nb03         =*/ nb03,
        /*.nb10         =*/ nb10,
        /*.nb11         =*/ nb11,
        /*.nb12         =*/ nb12,
        /*.ns12         =*/ nb12/nb10,
        /*.nb13         =*/ nb13,
        /*.nb20         =*/ nb20,
        /*.nb21         =*/ nb21,
        /*.ns21         =*/ nb21/nb20,
        /*.nb22         =*/ nb22,
        /*.ne30         =*/ ne30,
        /*.nb31         =*/ nb31,
        /*.nb41         =*/ nb41,
        /*.nb42         =*/ nb42,
        /*.ns42         =*/ nb42/nb40,
        /*.nb43         =*/ nb43,
        /*.nb51         =*/ nb51,
        /*.nb52         =*/ nb52,
        /*.ns52         =*/ nb52/nb50,
        /*.nb53         =*/ nb53,
        /*.nb0          =*/ nb0,
    };

    auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);

    // one thread per state element within a threadgroup
    GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

    const size_t smem = pipeline.smem;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), 4);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), 5);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), 6);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         8);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

    ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);

    return 1;
}
1503
1504int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1505    ggml_tensor * op = ctx->node(idx);
1506
1507    ggml_metal_library_t lib = ctx->lib;
1508    ggml_metal_encoder_t enc = ctx->enc;
1509
1510    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1511    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1512    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1513    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1514
1515    const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1516    const int64_t T = op->src[0]->ne[2];
1517    const int64_t C = op->ne[0];
1518    const int64_t H = op->src[0]->ne[1];
1519
1520    auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1521
1522    int ida = 0;
1523
1524    ggml_metal_encoder_set_pipeline(enc, pipeline);
1525    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1526    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1527    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1528    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1529    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1530    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1531    if (op->op == GGML_OP_RWKV_WKV7) {
1532        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1533    }
1534    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++);
1535    ggml_metal_encoder_set_bytes   (enc, (void *) &B, sizeof(B), ida++);
1536    ggml_metal_encoder_set_bytes   (enc, (void *) &T, sizeof(T), ida++);
1537    ggml_metal_encoder_set_bytes   (enc, (void *) &C, sizeof(C), ida++);
1538    ggml_metal_encoder_set_bytes   (enc, (void *) &H, sizeof(H), ida++);
1539
1540    ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1541
1542    return 1;
1543}
1544
1545int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1546    ggml_tensor * op = ctx->node(idx);
1547
1548    ggml_metal_library_t lib = ctx->lib;
1549    ggml_metal_encoder_t enc = ctx->enc;
1550
1551    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1552    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1553    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1554    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1555    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1556    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1557
1558    ggml_metal_kargs_solve_tri args = {
1559        /*.ne00 =*/ ne00,
1560        /*.ne01 =*/ ne01,
1561        /*.ne02 =*/ ne02,
1562        /*.ne03 =*/ ne03,
1563        /*.nb00 =*/ nb00,
1564        /*.nb01 =*/ nb01,
1565        /*.nb02 =*/ nb02,
1566        /*.nb03 =*/ nb03,
1567        /*.ne10 =*/ ne10,
1568        /*.ne11 =*/ ne11,
1569        /*.ne12 =*/ ne12,
1570        /*.ne13 =*/ ne13,
1571        /*.nb10 =*/ nb10,
1572        /*.nb11 =*/ nb11,
1573        /*.nb12 =*/ nb12,
1574        /*.nb13 =*/ nb13,
1575        /*.ne0  =*/ ne0,
1576        /*.ne1  =*/ ne1,
1577        /*.ne2  =*/ ne2,
1578        /*.ne3  =*/ ne3,
1579        /*.nb0  =*/ nb0,
1580        /*.nb1  =*/ nb1,
1581        /*.nb2  =*/ nb2,
1582        /*.nb3  =*/ nb3,
1583    };
1584
1585    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1586
1587    ggml_metal_encoder_set_pipeline(enc, pipeline);
1588    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
1589    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1590    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1591    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
1592
1593    const int nsg = pipeline.nsg;
1594
1595    ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1596
1597    ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1598
1599    return 1;
1600}
1601
1602int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1603    ggml_tensor * op = ctx->node(idx);
1604
1605    ggml_metal_library_t lib = ctx->lib;
1606    ggml_metal_encoder_t enc = ctx->enc;
1607
1608    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1609    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1610    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1611    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1612
1613    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1614
1615    GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1616
1617    int64_t nk0 = ne00;
1618    if (ggml_is_quantized(op->src[0]->type)) {
1619        nk0 = ne00/16;
1620    } else if (ggml_is_quantized(op->type)) {
1621        nk0 = ne00/ggml_blck_size(op->type);
1622    }
1623
1624    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1625
1626    // when rows are small, we can batch them together in a single threadgroup
1627    int nrptg = 1;
1628
1629    // TODO: relax this constraint in the future
1630    if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1631        if (nth > nk0) {
1632            nrptg = (nth + nk0 - 1)/nk0;
1633            nth   = nk0;
1634
1635            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1636                nrptg--;
1637            }
1638        }
1639    }
1640
1641    nth = std::min<int>(nth, nk0);
1642
1643    ggml_metal_kargs_cpy args = {
1644        /*.nk0  =*/ nk0,
1645        /*.ne00 =*/ ne00,
1646        /*.ne01 =*/ ne01,
1647        /*.ne02 =*/ ne02,
1648        /*.ne03 =*/ ne03,
1649        /*.nb00 =*/ nb00,
1650        /*.nb01 =*/ nb01,
1651        /*.nb02 =*/ nb02,
1652        /*.nb03 =*/ nb03,
1653        /*.ne0  =*/ ne0,
1654        /*.ne1  =*/ ne1,
1655        /*.ne2  =*/ ne2,
1656        /*.ne3  =*/ ne3,
1657        /*.nb0  =*/ nb0,
1658        /*.nb1  =*/ nb1,
1659        /*.nb2  =*/ nb2,
1660        /*.nb3  =*/ nb3,
1661    };
1662
1663    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1664
1665    ggml_metal_encoder_set_pipeline(enc, pipeline);
1666    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
1667    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1668    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
1669
1670    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1671
1672    return 1;
1673}
1674
1675int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1676    ggml_tensor * op = ctx->node(idx);
1677
1678    ggml_metal_library_t lib = ctx->lib;
1679    ggml_metal_encoder_t enc = ctx->enc;
1680
1681    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1682    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1683    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1684    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1685
1686    const int32_t * opts = op->op_params;
1687    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1688
1689    const int32_t k0 = opts[1];
1690    const int32_t s0 = opts[2];
1691    const int32_t p0 = opts[3];
1692
1693    const int64_t IW = op->src[0]->ne[0];
1694    const int64_t OW = op->ne[0];
1695
1696    const int64_t np = ggml_nelements(op);
1697
1698    ggml_metal_kargs_pool_1d args_pool_1d = {
1699        /* .k0 = */  k0,
1700        /* .s0 = */  s0,
1701        /* .p0 = */  p0,
1702        /* .IW = */  IW,
1703        /* .OW = */  OW,
1704        /* .np = */  np
1705    };
1706
1707    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1708
1709    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1710    const int ntg = (np + nth - 1) / nth;
1711
1712    ggml_metal_encoder_set_pipeline(enc, pipeline);
1713    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d),  0);
1714    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1715    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
1716
1717    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1718
1719    return 1;
1720}
1721
1722
1723int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1724    ggml_tensor * op = ctx->node(idx);
1725
1726    ggml_metal_library_t lib = ctx->lib;
1727    ggml_metal_encoder_t enc = ctx->enc;
1728
1729    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1730    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1731    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
1732    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
1733
1734    const int32_t * opts = op->op_params;
1735    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1736
1737    const int32_t k0 = opts[1];
1738    const int32_t k1 = opts[2];
1739    const int32_t s0 = opts[3];
1740    const int32_t s1 = opts[4];
1741    const int32_t p0 = opts[5];
1742    const int32_t p1 = opts[6];
1743
1744    const int64_t IH = op->src[0]->ne[1];
1745    const int64_t IW = op->src[0]->ne[0];
1746
1747    const int64_t N  = op->ne[3];
1748    const int64_t OC = op->ne[2];
1749    const int64_t OH = op->ne[1];
1750    const int64_t OW = op->ne[0];
1751
1752    const int64_t np = N * OC * OH * OW;
1753
1754    ggml_metal_kargs_pool_2d args_pool_2d = {
1755        /* .k0 = */ k0,
1756        /* .k1 = */ k1,
1757        /* .s0 = */ s0,
1758        /* .s1 = */ s1,
1759        /* .p0 = */ p0,
1760        /* .p1 = */ p1,
1761        /* .IH = */ IH,
1762        /* .IW = */ IW,
1763        /* .OH = */ OH,
1764        /* .OW = */ OW,
1765        /* .np = */ np
1766    };
1767
1768    auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1769
1770    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1771    const int ntg = (np + nth - 1) / nth;
1772
1773    ggml_metal_encoder_set_pipeline(enc, pipeline);
1774    ggml_metal_encoder_set_bytes   (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
1775    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1776    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
1777
1778    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1779
1780    return 1;
1781}
1782
// Encode a GGML_OP_MUL_MAT node.
//
// One of three kernel families is selected:
//   1. mul_mv_ext - small-batch matrix-vector kernels. Used when src1 is F32, ne00 % 128 == 0 and
//                   ne11 is in [2, 8], or in [4, 8] for Q4_K/Q5_K/Q6_K src0.
//   2. mul_mm     - simdgroup matrix-matrix kernel. Used when neither operand is transposed, the
//                   device reports has_simdgroup_mm, ne00 >= 64 and ne11 > ne11_mm_min.
//   3. mul_mv     - generic matrix-vector kernel for everything else.
//
// Returns 1 (presumably the number of graph nodes consumed - TODO confirm against the caller).
int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    GGML_ASSERT(ne00 == ne10);

    GGML_ASSERT(ne12 % ne02 == 0);
    GGML_ASSERT(ne13 % ne03 == 0);

    // broadcast ratios: each src0 matrix is reused r2 times along dim 2 and r3 times along dim 3 of src1
    // NOTE(review): stored as int16_t, so ratios above INT16_MAX would not fit
    const int16_t r2 = ne12/ne02;
    const int16_t r3 = ne13/ne03;

    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
    // to the matrix-vector kernel
    const int ne11_mm_min = 8;

    // first try to use small-batch mat-mv kernels
    // these should be efficient for BS [2, ~8]
    if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&
        (
         (
          (
           op->src[0]->type == GGML_TYPE_F32  || // TODO: helper function
           op->src[0]->type == GGML_TYPE_F16  ||
           op->src[0]->type == GGML_TYPE_Q4_0 ||
           op->src[0]->type == GGML_TYPE_Q4_1 ||
           op->src[0]->type == GGML_TYPE_Q5_0 ||
           op->src[0]->type == GGML_TYPE_Q5_1 ||
           op->src[0]->type == GGML_TYPE_Q8_0 ||
           op->src[0]->type == GGML_TYPE_MXFP4 ||
           op->src[0]->type == GGML_TYPE_IQ4_NL ||
           false) && (ne11 >= 2 && ne11 <= 8)
         ) ||
         (
          (
           op->src[0]->type == GGML_TYPE_Q4_K ||
           op->src[0]->type == GGML_TYPE_Q5_K ||
           op->src[0]->type == GGML_TYPE_Q6_K ||
           false) && (ne11 >= 4 && ne11 <= 8)
         )
        )
       ) {
        // TODO: determine the optimal parameters based on grid utilization
        //       I still don't know why we should not always use the maximum available threads:
        //
        //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
        //
        //       my current hypothesis is that the work grid is not evenly divisible for different nsg
        //       values and there can be some tail effects when nsg is high. need to confirm this
        //
        const int nsg    = 2;                 // num simdgroups per threadgroup

        // num threads along row per simdgroup
        // NOTE(review): the enclosing condition already requires ne00 % 128 == 0, so the final
        //               else branch (nxpsg = 4) is unreachable on this path
        int16_t nxpsg = 0;
        if (ne00 % 256 == 0 && ne11 < 3) {
            nxpsg = 16;
        } else if (ne00 % 128 == 0) {
            nxpsg = 8;
        } else {
            nxpsg = 4;
        }

        const int16_t nypsg  = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
        const int16_t r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup
              int16_t r1ptg  = 4;                 // num src1 rows per threadgroup

        // note: not sure how optimal are those across all different hardware. there might be something cleverer
        // (the outer condition guarantees 2 <= ne11 <= 8, so the default case is not expected to trigger)
        switch (ne11) {
            case 2:
                r1ptg = 2; break;
            case 3:
            case 6:
                r1ptg = 3; break;
            case 4:
            case 7:
            case 8:
                r1ptg = 4; break;
            case 5:
                r1ptg = 5; break;
            default:
                GGML_ABORT("unsupported ne11");
        };

        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);

        ggml_metal_kargs_mul_mv_ext args = {
            /*.ne00  =*/ ne00,
            /*.ne01  =*/ ne01,
            /*.ne02  =*/ ne02,
            /*.nb00  =*/ nb00,
            /*.nb01  =*/ nb01,
            /*.nb02  =*/ nb02,
            /*.nb03  =*/ nb03,
            /*.ne10  =*/ ne10,
            /*.ne11  =*/ ne11,
            /*.ne12  =*/ ne12,
            /*.nb10  =*/ nb10,
            /*.nb11  =*/ nb11,
            /*.nb12  =*/ nb12,
            /*.nb13  =*/ nb13,
            /*.ne0   =*/ ne0,
            /*.ne1   =*/ ne1,
            /*.r2    =*/ r2,
            /*.r3    =*/ r3,
        };

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);

        // grid: ceil(ne01/r0ptg) x ceil(ne11/r1ptg) x (ne12*ne13) threadgroups of 32 x nsg threads
        ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
    } else if (
        !ggml_is_transposed(op->src[0]) &&
        !ggml_is_transposed(op->src[1]) &&
        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
        props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
        //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

        // some Metal matrix data types require aligned pointers
        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
        //switch (op->src[0]->type) {
        //    case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
        //    case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
        //    case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
        //    default: break;
        //}

        auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);

        ggml_metal_kargs_mul_mm args = {
            /*.ne00 =*/ ne00,
            /*.ne02 =*/ ne02,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.ne12 =*/ ne12,
            /*.nb10 =*/ nb10,
            /*.nb11 =*/ nb11,
            /*.nb12 =*/ nb12,
            /*.nb13 =*/ nb13,
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
            /*.r2   =*/ r2,
            /*.r3   =*/ r3,
        };

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);

        // threadgroup memory size is taken from the pipeline descriptor
        const size_t smem = pipeline.smem;

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
        // grid: ceil(ne11/32) x ceil(ne01/64) x (ne12*ne13) threadgroups of 128 threads,
        // i.e. each threadgroup covers a tile of 32 src1 rows x 64 src0 rows
        ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
    } else {
        auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);

        // tiling parameters are chosen by the pipeline getter
        // (nr0: src0 rows, nr1: src1 rows, nsg: simdgroups per threadgroup)
        const int nr0 = pipeline.nr0;
        const int nr1 = pipeline.nr1;
        const int nsg = pipeline.nsg;

        const size_t smem = pipeline.smem;

        ggml_metal_kargs_mul_mv args = {
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.ne10 =*/ ne10,
            /*.ne11 =*/ ne11,
            /*.ne12 =*/ ne12,
            /*.nb10 =*/ nb10,
            /*.nb11 =*/ nb11,
            /*.nb12 =*/ nb12,
            /*.nb13 =*/ nb13,
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
            /*.nr0  =*/ nr0,
            /*.r2   =*/ r2,
            /*.r3   =*/ r3,
        };

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

        // F32/F16/BF16/Q8_0: each threadgroup covers nr0 src0 rows;
        // all other types: each threadgroup covers nr0*nsg src0 rows
        if (op->src[0]->type == GGML_TYPE_F32 ||
            op->src[0]->type == GGML_TYPE_F16 ||
            op->src[0]->type == GGML_TYPE_BF16 ||
            op->src[0]->type == GGML_TYPE_Q8_0) {
            ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
        } else {
            ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
        }
    }

    return 1;
}
2004
2005size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {
2006    assert(op->op == GGML_OP_MUL_MAT_ID);
2007
2008    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
2009
2010    return ggml_type_size(GGML_TYPE_I32)*ne02;
2011}
2012
2013size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
2014    assert(op->op == GGML_OP_MUL_MAT_ID);
2015
2016    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
2017    const int64_t ne21 = op->src[2]->ne[1]; // n_token
2018
2019    return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
2020}
2021
// Encode a GGML_OP_MUL_MAT_ID node: src0 holds one matrix per expert (ne02 = n_expert), src1 is the
// input, and src2 holds I32 expert ids (ne20 = n_expert_used, ne21 = n_tokens).
//
// Two strategies:
//   - matrix-matrix (has_simdgroup_mm && ne00 >= 64 && ne21 >= ne21_mm_id_min):
//       1. an id-map kernel (mul_mm_id_map0) reads src2 and writes the "tpe" and "ids" scratch buffers,
//          which live in extra space appended after dst (sized by ggml_metal_op_mul_mat_id_extra_tpe/_ids)
//       2. a concurrency barrier
//       3. the mul_mm_id kernel, which consumes those buffers
//   - matrix-vector (otherwise): a single mul_mv_id dispatch that reads src2 directly.
//
// Returns 1 (presumably the number of graph nodes consumed - TODO confirm against the caller).
int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    // src2 = ids
    GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);

    GGML_ASSERT(!ggml_is_transposed(op->src[0]));
    GGML_ASSERT(!ggml_is_transposed(op->src[1]));

    GGML_ASSERT(ne03 == 1);
    GGML_ASSERT(ne13 == 1);

    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);

    // no dim-2/3 broadcasting on this op (ne03 == ne13 == 1 asserted above)
    const uint32_t r2 = 1;
    const uint32_t r3 = 1;

    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
    // to the matrix-vector kernel
    // ne20 = n_used_experts
    // ne21 = n_rows (batch size)
    const int ne21_mm_id_min = 32;

    if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
        // some Metal matrix data types require aligned pointers
        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
        //switch (op->src[0]->type) {
        //    case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
        //    case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
        //    case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
        //    default: break;
        //}

        // extra buffers for intermediate id mapping
        // layout in the dst buffer: [ dst data | tpe (extra_tpe bytes) | ids (extra_ids bytes) ]
        ggml_metal_buffer_id bid_tpe = bid_dst;
        bid_tpe.offs += ggml_nbytes(op);

        ggml_metal_buffer_id bid_ids = bid_tpe;
        bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);

        {
            // NOTE(review): positional initializer - the order must match the field order of
            //               ggml_metal_kargs_mul_mm_id_map0 (presumably in ggml-metal-impl.h)
            ggml_metal_kargs_mul_mm_id_map0 args = {
                ne02,
                ne10,
                ne11, // n_expert_used (bcast)
                nb11,
                nb12,
                ne21, // n_tokens
                ne20, // n_expert_used
                nb21,
            };

            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);

            const size_t smem = pipeline.smem;

            // a single threadgroup with one thread per expert is dispatched below,
            // so the expert count must fit in one threadgroup
            GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

            GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);

            ggml_metal_encoder_set_pipeline(enc, pipeline);
            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_src2, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_tpe,  2);
            ggml_metal_encoder_set_buffer  (enc, bid_ids,  3);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

            ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
        }

        // this barrier is always needed because the next kernel has to wait for the id maps to be computed
        ggml_metal_op_concurrency_reset(ctx);

        {
            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);

            ggml_metal_kargs_mul_mm_id args = {
                /*.ne00  =*/ ne00,
                /*.ne02  =*/ ne02,
                /*.nb01  =*/ nb01,
                /*.nb02  =*/ nb02,
                /*.nb03  =*/ nb03,
                /*.ne11  =*/ ne11, // n_expert_used (bcast)
                /*.nb10  =*/ nb10,
                /*.nb11  =*/ nb11,
                /*.nb12  =*/ nb12,
                /*.nb13  =*/ nb13,
                /*.ne20  =*/ ne20, // n_expert_used
                /*.ne21  =*/ ne21, // n_tokens
                /*.ne0   =*/ ne0,
                /*.ne1   =*/ ne1,
                /*.r2    =*/ r2,
                /*.r3    =*/ r3,
            };

            ggml_metal_encoder_set_pipeline(enc, pipeline);
            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
            ggml_metal_encoder_set_buffer  (enc, bid_tpe,  3);
            ggml_metal_encoder_set_buffer  (enc, bid_ids,  4);
            ggml_metal_encoder_set_buffer  (enc, bid_dst,  5);

            const size_t smem = pipeline.smem;

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

            // grid: ceil(ne21/32) x ceil(ne01/64) x ne02 (one z-slice per expert), 128 threads per threadgroup
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
        }
    } else {
        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);

        // tiling parameters are chosen by the pipeline getter
        const int nr0 = pipeline.nr0;
        const int nr1 = pipeline.nr1;
        const int nsg = pipeline.nsg;

        const size_t smem = pipeline.smem;

        ggml_metal_kargs_mul_mv_id args = {
            /*.nei0 =*/ ne20,
            /*.nei1 =*/ ne21,
            /*.nbi1 =*/ nb21,
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.ne10 =*/ ne10,
            /*.ne11 =*/ ne11,
            /*.ne12 =*/ ne12,
            /*.ne13 =*/ ne13,
            /*.nb10 =*/ nb10,
            /*.nb11 =*/ nb11,
            /*.nb12 =*/ nb12,
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
            /*.nb1  =*/ nb1,
            /*.nr0  =*/ nr0,
        };

        if (ggml_is_quantized(op->src[0]->type)) {
            GGML_ASSERT(ne00 >= nsg*nr0);
        }

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
        ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
        ggml_metal_encoder_set_buffer(enc, bid_dst,  3);
        ggml_metal_encoder_set_buffer(enc, bid_src2, 4);

        // grid y is sized for a single src1 row; grid z enumerates all ne20*ne21 id entries
        const int64_t _ne1 = 1;
        const int64_t ne123 = ne20*ne21;

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

        // same row-tiling split as the plain mul_mv path: nr0 rows per threadgroup for
        // F32/F16/BF16/Q8_0, nr0*nsg rows per threadgroup otherwise
        if (op->src[0]->type == GGML_TYPE_F32 ||
            op->src[0]->type == GGML_TYPE_F16 ||
            op->src[0]->type == GGML_TYPE_BF16 ||
            op->src[0]->type == GGML_TYPE_Q8_0) {
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
        } else {
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
        }
    }

    return 1;
}
2209
2210int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
2211    ggml_tensor * op = ctx->node(idx);
2212
2213    ggml_metal_library_t lib = ctx->lib;
2214    ggml_metal_encoder_t enc = ctx->enc;
2215
2216    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2217    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2218    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2219    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2220    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2221    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2222    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
2223
2224    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2225    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2226    GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
2227    GGML_ASSERT(op->type         == GGML_TYPE_F32);
2228
2229    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2230
2231    ggml_metal_kargs_add_id args = {
2232        /*.ne0  =*/ ne0,
2233        /*.ne1  =*/ ne1,
2234        /*.nb01 =*/ nb01,
2235        /*.nb02 =*/ nb02,
2236        /*.nb11 =*/ nb11,
2237        /*.nb21 =*/ nb21,
2238    };
2239
2240    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
2241
2242    ggml_metal_encoder_set_pipeline(enc, pipeline);
2243    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
2244    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2245    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2246    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2247    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         4);
2248
2249    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
2250
2251    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
2252
2253    return 1;
2254}
2255
2256bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
2257    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2258
2259    const int64_t ne00 = op->src[0]->ne[0]; // head size
2260    const int64_t ne01 = op->src[0]->ne[1]; // batch size
2261
2262    // use vec kernel if the batch size is small and if the head size is supported
2263    return (ne01 < 20) && (ne00 % 32 == 0);
2264}
2265
2266size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
2267    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2268
2269    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2270    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2271    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2272    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2273    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2274    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2275    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2276    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2277
2278    size_t res = 0;
2279
2280    const bool has_mask = op->src[3] != nullptr;
2281
2282    // note: the non-vec kernel requires more extra memory, so always reserve for it
2283    GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2284
2285    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2286    if (false) {
2287        // note: always reserve the padding space to avoid graph reallocations
2288        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2289        const bool has_kvpad = true;
2290
2291        if (has_kvpad) {
2292            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2293                nb11*ne12*ne13 +
2294                nb21*ne22*ne23 +
2295                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2296        }
2297    } else {
2298        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2299        const bool has_kvpad = true;
2300
2301        if (has_kvpad) {
2302            res += OP_FLASH_ATTN_EXT_NCPSG*(
2303                nb11*ne12*ne13 +
2304                nb21*ne22*ne23 +
2305                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2306        }
2307    }
2308
2309    return res;
2310}
2311
2312size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2313    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2314
2315    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2316  //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2317  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2318  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2319  //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2320  //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2321    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2322    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2323
2324    size_t res = 0;
2325
2326    const bool has_mask = op->src[3] != nullptr;
2327
2328    if (!has_mask) {
2329        return res;
2330    }
2331
2332    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
2333
2334    // this optimization is not useful for the vector kernels
2335    // note: always reserve the blk buffer to avoid graph reallocations
2336    //if (is_vec) {
2337    //    return res;
2338    //}
2339
2340    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2341    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2342
2343    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2344    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2345
2346    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2347
2348    return res;
2349}
2350
2351size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
2352    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2353
2354    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2355    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2356  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2357  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2358    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2359    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2360  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2361  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2362
2363    size_t res = 0;
2364
2365    // note: always reserve the temp buffer to avoid graph reallocations
2366    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2367    if (true) {
2368        const int64_t nwg = 32;
2369        const int64_t ne01_max = std::min(ne01, 32);
2370
2371        // temp buffer for writing the results from each workgroup
2372        // - ne20: the size of the Value head
2373        // -  + 2: the S and M values for each intermediate result
2374        res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2375    }
2376
2377    return res;
2378}
2379
2380int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2381    ggml_tensor * op = ctx->node(idx);
2382
2383    ggml_metal_library_t lib = ctx->lib;
2384    ggml_metal_encoder_t enc = ctx->enc;
2385
2386    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
2387
2388    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2389    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2390    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2391    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2392    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2393    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2394    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2395    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2396    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
2397    GGML_TENSOR_LOCALS( int32_t, nb,  op,         nb);
2398
2399    GGML_ASSERT(ne00 % 4 == 0);
2400
2401    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2402    GGML_ASSERT(op->src[1]->type == op->src[2]->type);
2403
2404    //GGML_ASSERT(ggml_are_same_shape (src1, src2));
2405    GGML_ASSERT(ne11 == ne21);
2406    GGML_ASSERT(ne12 == ne22);
2407
2408    GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
2409    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2410            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
2411
2412    float scale;
2413    float max_bias;
2414    float logit_softcap;
2415
2416    memcpy(&scale,         ((const int32_t *) op->op_params) + 0, sizeof(scale));
2417    memcpy(&max_bias,      ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
2418    memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
2419
2420    if (logit_softcap != 0.0f) {
2421        scale /= logit_softcap;
2422    }
2423
2424    const bool has_mask  = op->src[3] != NULL;
2425    const bool has_sinks = op->src[4] != NULL;
2426    const bool has_bias  = max_bias != 0.0f;
2427    const bool has_scap  = logit_softcap != 0.0f;
2428
2429    const uint32_t n_head      = op->src[0]->ne[2];
2430    const  int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2431
2432    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
2433    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2434
2435    GGML_ASSERT(ne01 < 65536);
2436
2437    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2438    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2439    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2440    ggml_metal_buffer_id bid_src3 = has_mask  ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2441    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2442
2443    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2444
2445    ggml_metal_buffer_id bid_pad = bid_dst;
2446    bid_pad.offs += ggml_nbytes(op);
2447
2448    ggml_metal_buffer_id bid_blk = bid_pad;
2449    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2450
2451    ggml_metal_buffer_id bid_tmp = bid_blk;
2452    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
2453
2454    if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2455        // half8x8 kernel
2456        const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2457        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2458
2459        GGML_ASSERT(nqptg <= 32);
2460        GGML_ASSERT(nqptg  % 8  == 0);
2461        GGML_ASSERT(ncpsg  % 32 == 0);
2462
2463        bool need_sync = false;
2464
2465        const bool has_kvpad = ne11 % ncpsg != 0;
2466
2467        if (has_kvpad) {
2468            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2469
2470            ggml_metal_kargs_flash_attn_ext_pad args0 = {
2471                /*.ne11    =*/ne11,
2472                /*.ne_12_2 =*/ne12,
2473                /*.ne_12_3 =*/ne13,
2474                /*.nb11    =*/nb11,
2475                /*.nb12    =*/nb12,
2476                /*.nb13    =*/nb13,
2477                /*.nb21    =*/nb21,
2478                /*.nb22    =*/nb22,
2479                /*.nb23    =*/nb23,
2480                /*.ne31    =*/ne31,
2481                /*.ne32    =*/ne32,
2482                /*.ne33    =*/ne33,
2483                /*.nb31    =*/nb31,
2484                /*.nb32    =*/nb32,
2485                /*.nb33    =*/nb33,
2486            };
2487
2488            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2489
2490            ggml_metal_encoder_set_pipeline(enc, pipeline0);
2491            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
2492            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
2493            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
2494            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
2495            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
2496
2497            assert(ne12 == ne22);
2498            assert(ne13 == ne23);
2499
2500            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2501
2502            need_sync = true;
2503        }
2504
2505        if (has_mask) {
2506            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2507
2508            ggml_metal_kargs_flash_attn_ext_blk args0 = {
2509                /*.ne01 =*/ ne01,
2510                /*.ne30 =*/ ne30,
2511                /*.ne31 =*/ ne31,
2512                /*.ne32 =*/ ne32,
2513                /*.ne33 =*/ ne33,
2514                /*.nb31 =*/ nb31,
2515                /*.nb32 =*/ nb32,
2516                /*.nb33 =*/ nb33,
2517            };
2518
2519            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2520
2521            ggml_metal_encoder_set_pipeline(enc, pipeline0);
2522            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
2523            ggml_metal_encoder_set_buffer  (enc, bid_src3, 1);
2524            ggml_metal_encoder_set_buffer  (enc, bid_blk,  2);
2525
2526            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2527            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2528
2529            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2530
2531            need_sync = true;
2532        }
2533
2534        if (need_sync) {
2535            ggml_metal_op_concurrency_reset(ctx);
2536        }
2537
2538        const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
2539
2540        // 2*(2*ncpsg)
2541        // ncpsg soft_max values + ncpsg mask values
2542        //
2543        // 16*32*(nsg)
2544        // the shared memory needed for the simdgroups to load the KV cache
2545        // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
2546        //
2547#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
2548
2549        //int64_t nsgmax = 4;
2550        //
2551        //if (is_q) {
2552        //    nsgmax = 2;
2553        //    while (true) {
2554        //        const size_t smem = FATTN_SMEM(nsgmax);
2555        //        if (smem > props_dev->max_theadgroup_memory_size) {
2556        //            break;
2557        //        }
2558        //        nsgmax *= 2;
2559        //    }
2560        //    nsgmax /= 2;
2561        //}
2562
2563        // simdgroups per threadgroup (a.k.a. warps)
2564        //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2565        int32_t nsg = ne00 >= 512 ? 8 : 4;
2566
2567        const size_t smem = FATTN_SMEM(nsg);
2568
2569        ggml_metal_kargs_flash_attn_ext args = {
2570            /*.ne01          =*/ ne01,
2571            /*.ne02          =*/ ne02,
2572            /*.ne03          =*/ ne03,
2573            /*.nb01          =*/ nb01,
2574            /*.nb02          =*/ nb02,
2575            /*.nb03          =*/ nb03,
2576            /*.ne11          =*/ ne11,
2577            /*.ne_12_2       =*/ ne12,
2578            /*.ne_12_3       =*/ ne13,
2579            /*.ns10          =*/ int32_t(nb11/nb10),
2580            /*.nb11          =*/ nb11,
2581            /*.nb12          =*/ nb12,
2582            /*.nb13          =*/ nb13,
2583            /*.ns20          =*/ int32_t(nb21/nb20),
2584            /*.nb21          =*/ nb21,
2585            /*.nb22          =*/ nb22,
2586            /*.nb23          =*/ nb23,
2587            /*.ne31          =*/ ne31,
2588            /*.ne32          =*/ ne32,
2589            /*.ne33          =*/ ne33,
2590            /*.nb31          =*/ nb31,
2591            /*.nb32          =*/ nb32,
2592            /*.nb33          =*/ nb33,
2593            /*.ne1           =*/ ne1,
2594            /*.ne2           =*/ ne2,
2595            /*.ne3           =*/ ne3,
2596            /*.scale         =*/ scale,
2597            /*.max_bias      =*/ max_bias,
2598            /*.m0            =*/ m0,
2599            /*.m1            =*/ m1,
2600            /*.n_head_log2   =*/ n_head_log2,
2601            /*.logit_softcap =*/ logit_softcap,
2602        };
2603
2604        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2605
2606        ggml_metal_encoder_set_pipeline(enc, pipeline);
2607        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
2608        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
2609        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
2610        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
2611        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
2612        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
2613        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
2614        ggml_metal_encoder_set_buffer  (enc, bid_blk,  7);
2615        ggml_metal_encoder_set_buffer  (enc, bid_dst,  8);
2616
2617        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2618
2619        ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
2620#undef FATTN_SMEM
2621    } else {
2622        // half4x4 kernel
2623        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2624        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2625        const int nhptg = 1;                           // heads per threadgroup
2626
2627        GGML_ASSERT(nqptg <= 32);
2628        GGML_ASSERT(nqptg  % 1  == 0);
2629        GGML_ASSERT(ncpsg  % 32 == 0);
2630
2631        bool need_sync = false;
2632
2633        const bool has_kvpad = ne11 % ncpsg != 0;
2634
2635        if (has_kvpad) {
2636            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2637
2638            ggml_metal_kargs_flash_attn_ext_pad args0 = {
2639                /*.ne11    =*/ne11,
2640                /*.ne_12_2 =*/ne12,
2641                /*.ne_12_3 =*/ne13,
2642                /*.nb11    =*/nb11,
2643                /*.nb12    =*/nb12,
2644                /*.nb13    =*/nb13,
2645                /*.nb21    =*/nb21,
2646                /*.nb22    =*/nb22,
2647                /*.nb23    =*/nb23,
2648                /*.ne31    =*/ne31,
2649                /*.ne32    =*/ne32,
2650                /*.ne33    =*/ne33,
2651                /*.nb31    =*/nb31,
2652                /*.nb32    =*/nb32,
2653                /*.nb33    =*/nb33,
2654            };
2655
2656            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2657
2658            ggml_metal_encoder_set_pipeline(enc, pipeline0);
2659            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
2660            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
2661            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
2662            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
2663            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
2664
2665            assert(ne12 == ne22);
2666            assert(ne13 == ne23);
2667
2668            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2669
2670            need_sync = true;
2671        }
2672
2673        if (need_sync) {
2674            ggml_metal_op_concurrency_reset(ctx);
2675        }
2676
2677        // note: for simplicity assume the K is larger or equal than V
2678        GGML_ASSERT(ne10 >= ne20);
2679
2680        // ne00 + 2*ncpsg*(nsg)
2681        // for each query, we load it as f16 in shared memory (ne00)
2682        // and store the soft_max values and the mask
2683        //
2684        // ne20*(nsg)
2685        // each simdgroup has a full f32 head vector in shared mem to accumulate results
2686        //
2687#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
2688
2689        int64_t nsg = 1;
2690
2691        // workgroups
2692        // each workgroup handles nsg*nkpsg cache values
2693        int32_t nwg = 1;
2694        if (false) {
2695            // for small KV caches, we could launch a single workgroup and write the results directly to dst/
2696            // however, this does not lead to significant improvement, so disabled
2697            nwg = 1;
2698            nsg = 4;
2699        } else {
2700            nwg = 32;
2701            nsg = 1;
2702            while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2703                nsg *= 2;
2704            }
2705        }
2706
2707        ggml_metal_kargs_flash_attn_ext_vec args = {
2708            /*.ne01          =*/ ne01,
2709            /*.ne02          =*/ ne02,
2710            /*.ne03          =*/ ne03,
2711            /*.nb01          =*/ nb01,
2712            /*.nb02          =*/ nb02,
2713            /*.nb03          =*/ nb03,
2714            /*.ne11          =*/ ne11,
2715            /*.ne_12_2       =*/ ne12,
2716            /*.ne_12_3       =*/ ne13,
2717            /*.ns10          =*/ int32_t(nb11/nb10),
2718            /*.nb11          =*/ nb11,
2719            /*.nb12          =*/ nb12,
2720            /*.nb13          =*/ nb13,
2721            /*.ns20          =*/ int32_t(nb21/nb20),
2722            /*.nb21          =*/ nb21,
2723            /*.nb22          =*/ nb22,
2724            /*.nb23          =*/ nb23,
2725            /*.ne31          =*/ ne31,
2726            /*.ne32          =*/ ne32,
2727            /*.ne33          =*/ ne33,
2728            /*.nb31          =*/ nb31,
2729            /*.nb32          =*/ nb32,
2730            /*.nb33          =*/ nb33,
2731            /*.ne1           =*/ ne1,
2732            /*.ne2           =*/ ne2,
2733            /*.ne3           =*/ ne3,
2734            /*.scale         =*/ scale,
2735            /*.max_bias      =*/ max_bias,
2736            /*.m0            =*/ m0,
2737            /*.m1            =*/ m1,
2738            /*.n_head_log2   =*/ n_head_log2,
2739            /*.logit_softcap =*/ logit_softcap,
2740        };
2741
2742        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2743
2744        GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2745
2746        ggml_metal_encoder_set_pipeline(enc, pipeline);
2747        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
2748        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
2749        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
2750        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
2751        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
2752        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
2753
2754        const size_t smem = FATTN_SMEM(nsg);
2755
2756        //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2757        GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2758
2759        if (nwg == 1) {
2760            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2761
2762            // using 1 workgroup -> write the result directly into dst
2763            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2764            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2765
2766            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2767
2768            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2769        } else {
2770            // sanity checks
2771            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2772
2773            GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2774            GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2775
2776            // write the results from each workgroup into a temp buffer
2777            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2778            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2779
2780            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2781            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2782
2783            // sync the 2 kernels
2784            ggml_metal_op_concurrency_reset(ctx);
2785
2786            // reduce the results from the workgroups
2787            {
2788                const int32_t nrows = ne1*ne2*ne3;
2789
2790                ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
2791                    nrows,
2792                };
2793
2794                auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2795
2796                ggml_metal_encoder_set_pipeline(enc, pipeline0);
2797                ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
2798                ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
2799                ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
2800
2801                ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
2802            }
2803        }
2804#undef FATTN_SMEM
2805    }
2806
2807    return 1;
2808}
2809
// Encodes an element-wise binary op: op = src[0] <op> src[1], with both sources F32.
//
// When fusion is enabled, a run of consecutive GGML_OP_ADD nodes of the form
// c[i] = add(c[i-1], b[i]) (at most 8 nodes in total) is encoded as a single
// dispatch. The src1 offset of each fused node is recorded in args.o1[i], and the
// last fused node becomes the destination. A node is fused only if its b[i] has
// the same layout as the previous b and lives in the same Metal buffer as src[1].
//
// Returns the number of graph nodes consumed (1 when nothing was fused).
int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const bool use_fusion = ctx->use_fusion;

    const int debug_fusion = ctx->debug_fusion;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));

    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);

    // o1[0] is the byte offset of src[1] inside its Metal buffer; the fusion loop
    // below fills o1[1..] with the offsets of the extra fused operands
    ggml_metal_kargs_bin args = {
        /*.ne00 =*/ ne00,
        /*.ne01 =*/ ne01,
        /*.ne02 =*/ ne02,
        /*.ne03 =*/ ne03,
        /*.nb00 =*/ nb00,
        /*.nb01 =*/ nb01,
        /*.nb02 =*/ nb02,
        /*.nb03 =*/ nb03,
        /*.ne10 =*/ ne10,
        /*.ne11 =*/ ne11,
        /*.ne12 =*/ ne12,
        /*.ne13 =*/ ne13,
        /*.nb10 =*/ nb10,
        /*.nb11 =*/ nb11,
        /*.nb12 =*/ nb12,
        /*.nb13 =*/ nb13,
        /*.ne0  =*/ ne0,
        /*.ne1  =*/ ne1,
        /*.ne2  =*/ ne2,
        /*.ne3  =*/ ne3,
        /*.nb0  =*/ nb0,
        /*.nb1  =*/ nb1,
        /*.nb2  =*/ nb2,
        /*.nb3  =*/ nb3,
        /*.offs =*/ 0,
        /*.o1   =*/ { bid_src1.offs },
    };

    ggml_op fops[8];

    int n_fuse = 1;

    // c[0] = add(a,    b[0])
    // c[1] = add(c[0], b[1])
    // c[2] = add(c[1], b[2])
    // ...
    if (use_fusion) {
        fops[0] = GGML_OP_ADD;
        fops[1] = GGML_OP_ADD;
        fops[2] = GGML_OP_ADD;
        fops[3] = GGML_OP_ADD;
        fops[4] = GGML_OP_ADD;
        fops[5] = GGML_OP_ADD;
        fops[6] = GGML_OP_ADD;
        fops[7] = GGML_OP_ADD;

        // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
        //       across splits. idx_end indicates the last node in the current split
        // here n_fuse counts the successfully fused *pairs*; the ++n_fuse after the
        // loop turns it into the number of nodes (1..8)
        for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
            if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
                break;
            }

            ggml_tensor * f0 = ctx->node(idx + n_fuse);
            ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);

            // the next ADD must consume the previous result as its first operand
            if (f0 != f1->src[0]) {
                break;
            }

            // b[0] === b[1] === ...
            if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
                break;
            }

            // only fuse ops if src1 is in the same Metal buffer
            ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
            if (bid_fuse.metal != bid_src1.metal) {
                break;
            }

            //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;

            args.o1[n_fuse + 1] = bid_fuse.offs;
        }

        ++n_fuse;

        if (debug_fusion > 1 && n_fuse > 1) {
            GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
        }
    }

    // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
    // (must happen after the loop above, which compares against bid_src1.metal)
    bid_src1.offs = 0;

    struct ggml_metal_pipeline_with_params pipeline;

    pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);

    if (n_fuse > 1) {
        // the fused chain writes its final result into the last fused node
        bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));

        // if any fused node conflicts with in-flight work, insert a barrier before this dispatch
        for (int i = 1; i < n_fuse; ++i) {
            if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
                ggml_metal_op_concurrency_reset(ctx);

                break;
            }
        }
    }

    // with the c4 pipeline the row lengths are expressed in units of 4 elements
    if (pipeline.c4) {
        args.ne00 = ne00/4;
        args.ne10 = ne10/4;
        args.ne0  = ne0/4;
    }

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
    ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
    ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);

    if (pipeline.cnt) {
        // one single-thread threadgroup per (possibly 4-wide) output element
        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);

        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
    } else {
        // one threadgroup per src0 row; the width is a power of two grown toward args.ne0,
        // capped at 256 and at the pipeline limit
        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

        int nth = 1;

        while (2*nth < args.ne0 && nth < nth_max) {
            nth *= 2;
        }

        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
    }

    return n_fuse;
}
2970
2971int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2972    ggml_tensor * op = ctx->node(idx);
2973
2974    ggml_metal_library_t lib = ctx->lib;
2975    ggml_metal_encoder_t enc = ctx->enc;
2976
2977    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2978    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2979    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
2980    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
2981
2982    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2983
2984    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2985    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
2986
2987    float eps;
2988    memcpy(&eps, op->op_params, sizeof(float));
2989
2990    ggml_metal_kargs_l2_norm args = {
2991        /*.ne00  =*/ ne00,
2992        /*.ne01  =*/ ne01,
2993        /*.ne02  =*/ ne02,
2994        /*.ne03  =*/ ne03,
2995        /*.nb00  =*/ nb00,
2996        /*.nb01  =*/ nb01,
2997        /*.nb02  =*/ nb02,
2998        /*.nb03  =*/ nb03,
2999        /*.ne0   =*/ ne0,
3000        /*.ne1   =*/ ne1,
3001        /*.ne2   =*/ ne2,
3002        /*.ne3   =*/ ne3,
3003        /*.nb0   =*/ nb0,
3004        /*.nb1   =*/ nb1,
3005        /*.nb2   =*/ nb2,
3006        /*.nb3   =*/ nb3,
3007        /*.eps   =*/ eps,
3008    };
3009
3010    auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
3011
3012    if (pipeline.c4) {
3013        args.ne00 = ne00/4;
3014        args.ne0  = ne0/4;
3015    }
3016
3017    int nth = 32; // SIMD width
3018
3019    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3020        nth *= 2;
3021    }
3022
3023    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3024
3025    const size_t smem = pipeline.smem;
3026
3027    ggml_metal_encoder_set_pipeline(enc, pipeline);
3028    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3029    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
3030    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
3031
3032    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3033
3034    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3035
3036    return 1;
3037}
3038
3039int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
3040    ggml_tensor * op = ctx->node(idx);
3041
3042    ggml_metal_library_t lib = ctx->lib;
3043    ggml_metal_encoder_t enc = ctx->enc;
3044
3045    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3046    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3047    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3048    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3049
3050    const int32_t ngrp = ((const int32_t *) op->op_params)[0];
3051
3052    float eps;
3053    memcpy(&eps, op->op_params + 1, sizeof(float));
3054
3055    ggml_metal_kargs_group_norm args = {
3056        /*.ne00 =*/ ne00,
3057        /*.ne01 =*/ ne01,
3058        /*.ne02 =*/ ne02,
3059        /*.nb00 =*/ nb00,
3060        /*.nb01 =*/ nb01,
3061        /*.nb02 =*/ nb02,
3062        /*.ngrp =*/ ngrp,
3063        /*.eps  =*/ eps,
3064    };
3065
3066    auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
3067
3068    int nth = 32; // SIMD width
3069    //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3070    //    nth *= 2;
3071    //}
3072
3073    //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3074    //nth = std::min(nth, ne00/4);
3075
3076    const size_t smem = pipeline.smem;
3077
3078    ggml_metal_encoder_set_pipeline(enc, pipeline);
3079    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3080    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3081    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3082
3083    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3084
3085    ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
3086
3087    return 1;
3088}
3089
3090int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
3091    ggml_tensor * op = ctx->node(idx);
3092
3093    ggml_metal_library_t lib = ctx->lib;
3094    ggml_metal_encoder_t enc = ctx->enc;
3095
3096    const bool use_fusion = ctx->use_fusion;
3097
3098    const int debug_fusion = ctx->debug_fusion;
3099
3100    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3101    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3102    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3103    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3104
3105    float eps;
3106    memcpy(&eps, op->op_params, sizeof(float));
3107
3108    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3109    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
3110
3111    ggml_metal_kargs_norm args = {
3112        /*.ne00   =*/ ne00,
3113        /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
3114        /*.nb1    =*/ nb1,
3115        /*.nb2    =*/ nb2,
3116        /*.nb3    =*/ nb3,
3117        /*.eps    =*/ eps,
3118        /*.nef1   =*/ { ne01 },
3119        /*.nef2   =*/ { ne02 },
3120        /*.nef3   =*/ { ne03 },
3121        /*.nbf1   =*/ { nb01 },
3122        /*.nbf2   =*/ { nb02 },
3123        /*.nbf3   =*/ { nb03 },
3124    };
3125
3126    ggml_op fops[8];
3127
3128    int n_fuse = 1;
3129
3130    ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
3131
3132    // d[0] = norm(a)
3133    // d[1] = mul(d[0], b)
3134    // d[2] = add(d[1], c)
3135    if (use_fusion) {
3136        fops[0] = op->op;
3137        fops[1] = GGML_OP_MUL;
3138        fops[2] = GGML_OP_ADD;
3139
3140        for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
3141            if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
3142                break;
3143            }
3144
3145            ggml_tensor * f0 = ctx->node(idx + n_fuse);
3146            ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
3147
3148            if (f0 != f1->src[0]) {
3149                break;
3150            }
3151
3152            if (f1->src[1]->ne[0] != op->ne[0]) {
3153                break;
3154            }
3155
3156            if (!ggml_is_contiguous_rows(f1->src[1])) {
3157                break;
3158            }
3159
3160            if (f1->type != GGML_TYPE_F32) {
3161                break;
3162            }
3163
3164            //ctx->fuse_cnt[f1->op]++;
3165
3166            bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
3167
3168            args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
3169            args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
3170            args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
3171
3172            args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
3173            args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
3174            args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
3175        }
3176
3177        ++n_fuse;
3178
3179        if (debug_fusion > 1 && n_fuse > 1) {
3180            if (n_fuse == 2) {
3181                GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
3182            }
3183            if (n_fuse == 3) {
3184                GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
3185            }
3186        }
3187    }
3188
3189    if (n_fuse > 1) {
3190        bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
3191
3192        for (int i = 1; i < n_fuse; ++i) {
3193            if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
3194                ggml_metal_op_concurrency_reset(ctx);
3195
3196                break;
3197            }
3198        }
3199    }
3200
3201    auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3202
3203    int nth = 32; // SIMD width
3204
3205    while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3206        nth *= 2;
3207    }
3208
3209    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3210    nth = std::min(nth, args.ne00_t);
3211
3212    const size_t smem = pipeline.smem;
3213
3214    ggml_metal_encoder_set_pipeline(enc, pipeline);
3215    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3216    ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);
3217    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
3218    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
3219    ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);
3220
3221    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3222
3223    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3224
3225    return n_fuse;
3226}
3227
3228int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
3229    ggml_tensor * op = ctx->node(idx);
3230
3231    ggml_metal_library_t lib = ctx->lib;
3232    ggml_metal_encoder_t enc = ctx->enc;
3233
3234    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3235    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3236    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3237    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3238    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3239    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3240
3241    // make sure we have one or more position id(ne10) per token(ne02)
3242    GGML_ASSERT(ne10 % ne02 == 0);
3243    GGML_ASSERT(ne10 >= ne02);
3244
3245    const int nth = std::min(1024, ne00);
3246
3247    const int n_past     = ((const int32_t *) op->op_params)[0];
3248    const int n_dims     = ((const int32_t *) op->op_params)[1];
3249  //const int mode       = ((const int32_t *) op->op_params)[2];
3250    // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
3251    const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
3252
3253    float freq_base;
3254    float freq_scale;
3255    float ext_factor;
3256    float attn_factor;
3257    float beta_fast;
3258    float beta_slow;
3259
3260    memcpy(&freq_base,   (const int32_t *) op->op_params +  5, sizeof(float));
3261    memcpy(&freq_scale,  (const int32_t *) op->op_params +  6, sizeof(float));
3262    memcpy(&ext_factor,  (const int32_t *) op->op_params +  7, sizeof(float));
3263    memcpy(&attn_factor, (const int32_t *) op->op_params +  8, sizeof(float));
3264    memcpy(&beta_fast,   (const int32_t *) op->op_params +  9, sizeof(float));
3265    memcpy(&beta_slow,   (const int32_t *) op->op_params + 10, sizeof(float));
3266
3267    // mrope
3268    const int sect_0 = ((const int32_t *) op->op_params)[11];
3269    const int sect_1 = ((const int32_t *) op->op_params)[12];
3270    const int sect_2 = ((const int32_t *) op->op_params)[13];
3271    const int sect_3 = ((const int32_t *) op->op_params)[14];
3272
3273    ggml_metal_kargs_rope args = {
3274        /*.ne00        =*/ ne00,
3275        /*.ne01        =*/ ne01,
3276        /*.ne02        =*/ ne02,
3277        /*.ne03        =*/ ne03,
3278        /*.nb00        =*/ nb00,
3279        /*.nb01        =*/ nb01,
3280        /*.nb02        =*/ nb02,
3281        /*.nb03        =*/ nb03,
3282        /*.ne0         =*/ ne0,
3283        /*.ne1         =*/ ne1,
3284        /*.ne2         =*/ ne2,
3285        /*.ne3         =*/ ne3,
3286        /*.nb0         =*/ nb0,
3287        /*.nb1         =*/ nb1,
3288        /*.nb2         =*/ nb2,
3289        /*.nb3         =*/ nb3,
3290        /*.n_past      =*/ n_past,
3291        /*.n_dims      =*/ n_dims,
3292        /*.n_ctx_orig  =*/ n_ctx_orig,
3293        /*.freq_base   =*/ freq_base,
3294        /*.freq_scale  =*/ freq_scale,
3295        /*.ext_factor  =*/ ext_factor,
3296        /*.attn_factor =*/ attn_factor,
3297        /*.beta_fast   =*/ beta_fast,
3298        /*.beta_slow   =*/ beta_slow,
3299        /* sect_0      =*/ sect_0,
3300        /* sect_1      =*/ sect_1,
3301        /* sect_2      =*/ sect_2,
3302        /* sect_3      =*/ sect_3,
3303        /* src2        =*/ op->src[2] != nullptr,
3304    };
3305
3306    auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
3307
3308    ggml_metal_encoder_set_pipeline(enc, pipeline);
3309    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3310    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3311    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3312    if (op->src[2]) {
3313        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
3314    } else {
3315        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 3);
3316    }
3317    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         4);
3318
3319    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3320
3321    return 1;
3322}
3323
3324int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
3325    ggml_tensor * op = ctx->node(idx);
3326
3327    ggml_metal_library_t lib = ctx->lib;
3328    ggml_metal_encoder_t enc = ctx->enc;
3329
3330    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3331    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3332    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3333    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3334
3335    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3336    const int32_t s1 = ((const int32_t *)(op->op_params))[1];
3337    const int32_t p0 = ((const int32_t *)(op->op_params))[2];
3338    const int32_t p1 = ((const int32_t *)(op->op_params))[3];
3339    const int32_t d0 = ((const int32_t *)(op->op_params))[4];
3340    const int32_t d1 = ((const int32_t *)(op->op_params))[5];
3341
3342    const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
3343
3344    const int32_t N  = op->src[1]->ne[is_2D ? 3 : 2];
3345    const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
3346    const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
3347    const int32_t IW =         op->src[1]->ne[0];
3348
3349    const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
3350    const int32_t KW =         op->src[0]->ne[0];
3351
3352    const int32_t OH = is_2D ? op->ne[2] : 1;
3353    const int32_t OW =         op->ne[1];
3354
3355    const int32_t CHW = IC * KH * KW;
3356
3357    const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
3358    const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
3359
3360    ggml_metal_kargs_im2col args = {
3361        /*.ofs0 =*/ ofs0,
3362        /*.ofs1 =*/ ofs1,
3363        /*.IW   =*/ IW,
3364        /*.IH   =*/ IH,
3365        /*.CHW  =*/ CHW,
3366        /*.s0   =*/ s0,
3367        /*.s1   =*/ s1,
3368        /*.p0   =*/ p0,
3369        /*.p1   =*/ p1,
3370        /*.d0   =*/ d0,
3371        /*.d1   =*/ d1,
3372        /*.N    =*/ N,
3373        /*.KH   =*/ KH,
3374        /*.KW   =*/ KW,
3375        /*.KHW  =*/ KH * KW,
3376    };
3377
3378    auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3379
3380    GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3381
3382    const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3383
3384    ggml_metal_encoder_set_pipeline(enc, pipeline);
3385    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3386    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3387    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3388
3389    ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3390
3391    return 1;
3392}
3393
3394int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3395    ggml_tensor * op = ctx->node(idx);
3396
3397    ggml_metal_library_t lib = ctx->lib;
3398    ggml_metal_encoder_t enc = ctx->enc;
3399
3400    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3401    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3402    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3403    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3404    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3405    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3406
3407    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3408    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
3409    GGML_ASSERT(op->type == GGML_TYPE_F32);
3410    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
3411
3412    const int32_t s0 = ((const int32_t *) op->op_params)[0];
3413    const int32_t s1 = ((const int32_t *) op->op_params)[1];
3414    const int32_t p0 = ((const int32_t *) op->op_params)[2];
3415    const int32_t p1 = ((const int32_t *) op->op_params)[3];
3416    const int32_t d0 = ((const int32_t *) op->op_params)[4];
3417    const int32_t d1 = ((const int32_t *) op->op_params)[5];
3418
3419    ggml_metal_kargs_conv_2d args = {
3420        /*.nb00 =*/ nb00,
3421        /*.nb01 =*/ nb01,
3422        /*.nb02 =*/ nb02,
3423        /*.nb03 =*/ nb03,
3424        /*.nb10 =*/ nb10,
3425        /*.nb11 =*/ nb11,
3426        /*.nb12 =*/ nb12,
3427        /*.nb13 =*/ nb13,
3428        /*.nb0  =*/ nb0,
3429        /*.nb1  =*/ nb1,
3430        /*.nb2  =*/ nb2,
3431        /*.nb3  =*/ nb3,
3432        /*.IW   =*/ ne10,
3433        /*.IH   =*/ ne11,
3434        /*.KW   =*/ ne00,
3435        /*.KH   =*/ ne01,
3436        /*.IC   =*/ ne02,
3437        /*.OC   =*/ ne03,
3438        /*.OW   =*/ ne0,
3439        /*.OH   =*/ ne1,
3440        /*.N    =*/ ne3,
3441        /*.s0   =*/ s0,
3442        /*.s1   =*/ s1,
3443        /*.p0   =*/ p0,
3444        /*.p1   =*/ p1,
3445        /*.d0   =*/ d0,
3446        /*.d1   =*/ d1,
3447    };
3448
3449    auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
3450
3451    int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3452    nth = std::min(nth, 256);
3453    nth = std::max(nth, 1);
3454
3455    const uint64_t n_out = ggml_nelements(op);
3456
3457    uint64_t tg = (n_out + nth - 1)/nth;
3458    tg = std::max<uint64_t>(tg, 1);
3459    tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3460
3461    ggml_metal_encoder_set_pipeline(enc, pipeline);
3462    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3463    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3464    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3465    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
3466
3467    ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3468
3469    return 1;
3470}
3471
3472int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3473    ggml_tensor * op = ctx->node(idx);
3474
3475    ggml_metal_library_t lib = ctx->lib;
3476    ggml_metal_encoder_t enc = ctx->enc;
3477
3478    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3479    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3480    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3481    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3482    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3483    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3484
3485    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3486
3487    const int32_t IC = op->src[1]->ne[1];
3488    const int32_t IL = op->src[1]->ne[0];
3489
3490    const int32_t K  = op->src[0]->ne[0];
3491
3492    const int32_t OL = op->ne[0];
3493    const int32_t OC = op->ne[1];
3494
3495    ggml_metal_kargs_conv_transpose_1d args = {
3496        /*.IC  =*/ IC,
3497        /*.IL  =*/ IL,
3498        /*.K   =*/ K,
3499        /*.s0  =*/ s0,
3500        /*.nb0 =*/ nb0,
3501        /*.nb1 =*/ nb1,
3502    };
3503
3504    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3505
3506    ggml_metal_encoder_set_pipeline(enc, pipeline);
3507    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3508    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3509    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3510    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
3511
3512    ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
3513
3514    return 1;
3515}
3516
3517int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
3518    ggml_tensor * op = ctx->node(idx);
3519
3520    ggml_metal_library_t lib = ctx->lib;
3521    ggml_metal_encoder_t enc = ctx->enc;
3522
3523    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3524    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3525    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3526    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3527    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3528    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3529
3530    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3531
3532    const int32_t IC = op->src[1]->ne[2];
3533    const int32_t IH = op->src[1]->ne[1];
3534    const int32_t IW = op->src[1]->ne[0];
3535
3536    const int32_t KH = op->src[0]->ne[1];
3537    const int32_t KW = op->src[0]->ne[0];
3538
3539    const int32_t OW = op->ne[0];
3540    const int32_t OH = op->ne[1];
3541    const int32_t OC = op->ne[2];
3542
3543    ggml_metal_kargs_conv_transpose_2d args = {
3544        /*.IC  =*/ IC,
3545        /*.IH  =*/ IH,
3546        /*.IW  =*/ IW,
3547        /*.KH  =*/ KH,
3548        /*.KW  =*/ KW,
3549        /*.OC  =*/ OC,
3550        /*.s0  =*/ s0,
3551        /*.nb0 =*/ nb0,
3552        /*.nb1 =*/ nb1,
3553        /*.nb2 =*/ nb2,
3554    };
3555
3556    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3557
3558    ggml_metal_encoder_set_pipeline(enc, pipeline);
3559    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3560    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3561    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3562    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
3563
3564    // Metal requires buffer size to be multiple of 16 bytes
3565    const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
3566    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3567
3568    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3569
3570    return 1;
3571}
3572
3573int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3574    ggml_tensor * op = ctx->node(idx);
3575
3576    ggml_metal_library_t lib = ctx->lib;
3577    ggml_metal_encoder_t enc = ctx->enc;
3578
3579    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3580    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3581    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3582    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3583
3584    const float sf0 = (float)ne0/op->src[0]->ne[0];
3585    const float sf1 = (float)ne1/op->src[0]->ne[1];
3586    const float sf2 = (float)ne2/op->src[0]->ne[2];
3587    const float sf3 = (float)ne3/op->src[0]->ne[3];
3588
3589    ggml_metal_kargs_upscale args = {
3590        /*.ne00 =*/ ne00,
3591        /*.ne01 =*/ ne01,
3592        /*.ne02 =*/ ne02,
3593        /*.ne03 =*/ ne03,
3594        /*.nb00 =*/ nb00,
3595        /*.nb01 =*/ nb01,
3596        /*.nb02 =*/ nb02,
3597        /*.nb03 =*/ nb03,
3598        /*.ne0 =*/ ne0,
3599        /*.ne1 =*/ ne1,
3600        /*.ne2 =*/ ne2,
3601        /*.ne3 =*/ ne3,
3602        /*.nb0 =*/ nb0,
3603        /*.nb1 =*/ nb1,
3604        /*.nb2 =*/ nb2,
3605        /*.nb3 =*/ nb3,
3606        /*.sf0 =*/ sf0,
3607        /*.sf1 =*/ sf1,
3608        /*.sf2 =*/ sf2,
3609        /*.sf3 =*/ sf3
3610    };
3611
3612    auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3613
3614    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3615
3616    ggml_metal_encoder_set_pipeline(enc, pipeline);
3617    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3618    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3619    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3620
3621    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3622
3623    return 1;
3624}
3625
3626int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
3627    ggml_tensor * op = ctx->node(idx);
3628
3629    ggml_metal_library_t lib = ctx->lib;
3630    ggml_metal_encoder_t enc = ctx->enc;
3631
3632    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3633    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3634    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3635    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3636
3637    ggml_metal_kargs_pad args = {
3638        /*.ne00 =*/ ne00,
3639        /*.ne01 =*/ ne01,
3640        /*.ne02 =*/ ne02,
3641        /*.ne03 =*/ ne03,
3642        /*.nb00 =*/ nb00,
3643        /*.nb01 =*/ nb01,
3644        /*.nb02 =*/ nb02,
3645        /*.nb03 =*/ nb03,
3646        /*.ne0  =*/ ne0,
3647        /*.ne1  =*/ ne1,
3648        /*.ne2  =*/ ne2,
3649        /*.ne3  =*/ ne3,
3650        /*.nb0  =*/ nb0,
3651        /*.nb1  =*/ nb1,
3652        /*.nb2  =*/ nb2,
3653        /*.nb3  =*/ nb3
3654    };
3655
3656    auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3657
3658    const int nth = std::min(1024, ne0);
3659
3660    ggml_metal_encoder_set_pipeline(enc, pipeline);
3661    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3662    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3663    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3664
3665    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3666
3667    return 1;
3668}
3669
3670int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
3671    ggml_tensor * op = ctx->node(idx);
3672
3673    ggml_metal_library_t lib = ctx->lib;
3674    ggml_metal_encoder_t enc = ctx->enc;
3675
3676    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3677    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3678    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3679    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3680
3681    ggml_metal_kargs_pad_reflect_1d args = {
3682        /*.ne00 =*/ ne00,
3683        /*.ne01 =*/ ne01,
3684        /*.ne02 =*/ ne02,
3685        /*.ne03 =*/ ne03,
3686        /*.nb00 =*/ nb00,
3687        /*.nb01 =*/ nb01,
3688        /*.nb02 =*/ nb02,
3689        /*.nb03 =*/ nb03,
3690        /*.ne0  =*/ ne0,
3691        /*.ne1  =*/ ne1,
3692        /*.ne2  =*/ ne2,
3693        /*.ne3  =*/ ne3,
3694        /*.nb0  =*/ nb0,
3695        /*.nb1  =*/ nb1,
3696        /*.nb2  =*/ nb2,
3697        /*.nb3  =*/ nb3,
3698        /*.p0 =*/ ((const int32_t *)(op->op_params))[0],
3699        /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
3700    };
3701
3702    auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3703
3704    const int nth = std::min(1024, ne0);
3705
3706    ggml_metal_encoder_set_pipeline(enc, pipeline);
3707    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3708    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3709    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3710
3711    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3712
3713    return 1;
3714}
3715
3716int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
3717    ggml_tensor * op = ctx->node(idx);
3718
3719    ggml_metal_library_t lib = ctx->lib;
3720    ggml_metal_encoder_t enc = ctx->enc;
3721
3722    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3723    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3724
3725    float start;
3726    float step;
3727
3728    memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
3729    memcpy(&step,  ((const int32_t *) op->op_params) + 2, sizeof(float));
3730
3731    ggml_metal_kargs_arange args = {
3732        /*.ne0   =*/ ne0,
3733        /*.start =*/ start,
3734        /*.step  =*/ step
3735    };
3736
3737    const int nth = std::min(1024, ne0);
3738
3739    auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
3740
3741    ggml_metal_encoder_set_pipeline(enc, pipeline);
3742    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3743    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 1);
3744
3745    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
3746
3747    return 1;
3748}
3749
3750int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3751    ggml_tensor * op = ctx->node(idx);
3752
3753    ggml_metal_library_t lib = ctx->lib;
3754    ggml_metal_encoder_t enc = ctx->enc;
3755
3756    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3757    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3758    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3759    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3760
3761    const int dim        = op->op_params[0];
3762    const int max_period = op->op_params[1];
3763
3764    ggml_metal_kargs_timestep_embedding args = {
3765        /*.nb1 =*/ nb1,
3766        /*.dim =*/ dim,
3767        /*.max_period =*/ max_period,
3768    };
3769
3770    auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3771
3772    const int nth = std::max(1, std::min(1024, dim/2));
3773
3774    ggml_metal_encoder_set_pipeline(enc, pipeline);
3775    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3776    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3777    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3778
3779    ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
3780
3781    return 1;
3782}
3783
3784int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3785    ggml_tensor * op = ctx->node(idx);
3786
3787    ggml_metal_library_t lib = ctx->lib;
3788    ggml_metal_encoder_t enc = ctx->enc;
3789
3790    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3791    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3792    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3793    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3794
3795    ggml_metal_kargs_argmax args = {
3796        /*.ne00 = */ ne00,
3797        /*.nb01 = */ nb01,
3798    };
3799
3800    auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3801
3802    const int64_t nrows = ggml_nrows(op->src[0]);
3803
3804    int nth = 32; // SIMD width
3805    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
3806        nth *= 2;
3807    }
3808
3809    const size_t smem = pipeline.smem;
3810
3811    ggml_metal_encoder_set_pipeline(enc, pipeline);
3812    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3813    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3814    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
3815
3816    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3817
3818    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3819
3820    return 1;
3821}
3822
3823int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
3824    ggml_tensor * op = ctx->node(idx);
3825
3826    ggml_metal_library_t lib = ctx->lib;
3827    ggml_metal_encoder_t enc = ctx->enc;
3828
3829    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3830
3831    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3832    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3833    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3834    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3835
3836    auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3837
3838    // bitonic sort requires the number of elements to be power of 2
3839    int nth = 1;
3840    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3841        nth *= 2;
3842    }
3843
3844    const int npr = (ne00 + nth - 1)/nth;
3845
3846    // Metal kernels require the buffer size to be multiple of 16 bytes
3847    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3848    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3849
3850    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3851    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
3852
3853    ggml_metal_buffer_id bid_tmp = bid_dst;
3854    bid_tmp.offs += ggml_nbytes(op);
3855
3856    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3857        std::swap(bid_dst, bid_tmp);
3858    }
3859
3860    ggml_metal_kargs_argsort args = {
3861        /*.ne00  =*/ ne00,
3862        /*.ne01  =*/ ne01,
3863        /*.ne02  =*/ ne02,
3864        /*.ne03  =*/ ne03,
3865        /*.nb00  =*/ nb00,
3866        /*.nb01  =*/ nb01,
3867        /*.nb02  =*/ nb02,
3868        /*.nb03  =*/ nb03,
3869        /*.ne0   =*/ ne0,
3870        /*.ne1   =*/ ne1,
3871        /*.ne2   =*/ ne2,
3872        /*.ne3   =*/ ne3,
3873        /*.top_k =*/ nth,
3874    };
3875
3876    ggml_metal_encoder_set_pipeline(enc, pipeline);
3877    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3878    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
3879    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
3880
3881    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3882
3883    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3884
3885    auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3886
3887    int len = nth;
3888
3889    while (len < ne00) {
3890        ggml_metal_op_concurrency_reset(ctx);
3891
3892        ggml_metal_kargs_argsort_merge args_merge = {
3893            /*.ne00  =*/ ne00,
3894            /*.ne01  =*/ ne01,
3895            /*.ne02  =*/ ne02,
3896            /*.ne03  =*/ ne03,
3897            /*.nb00  =*/ nb00,
3898            /*.nb01  =*/ nb01,
3899            /*.nb02  =*/ nb02,
3900            /*.nb03  =*/ nb03,
3901            /*.ne0   =*/ ne0,
3902            /*.ne1   =*/ ne1,
3903            /*.ne2   =*/ ne2,
3904            /*.ne3   =*/ ne3,
3905            /*.top_k =*/ ne00,
3906            /*.len   =*/ len,
3907        };
3908
3909        // merges per row
3910        const int nm = (ne00 + 2*len - 1) / (2*len);
3911
3912        const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3913
3914        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3915        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
3916        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
3917        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
3918        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
3919
3920        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3921
3922        std::swap(bid_dst, bid_tmp);
3923
3924        len <<= 1;
3925    }
3926
3927    return 1;
3928}
3929
3930int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3931    ggml_tensor * op = ctx->node(idx);
3932
3933    ggml_metal_library_t lib = ctx->lib;
3934    ggml_metal_encoder_t enc = ctx->enc;
3935
3936    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3937
3938    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3939    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3940    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
3941    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
3942
3943    auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
3944
3945    // bitonic sort requires the number of elements to be power of 2
3946    int nth = 1;
3947    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3948        nth *= 2;
3949    }
3950
3951    // blocks per row
3952    const int npr = (ne00 + nth - 1)/nth;
3953
3954    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3955
3956    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3957    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
3958
3959    ggml_metal_buffer_id bid_tmp = bid_dst;
3960    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3961
3962    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3963        std::swap(bid_dst, bid_tmp);
3964    }
3965
3966    const int top_k = ne0;
3967
3968    ggml_metal_kargs_argsort args = {
3969        /*.ne00  =*/ ne00,
3970        /*.ne01  =*/ ne01,
3971        /*.ne02  =*/ ne02,
3972        /*.ne03  =*/ ne03,
3973        /*.nb00  =*/ nb00,
3974        /*.nb01  =*/ nb01,
3975        /*.nb02  =*/ nb02,
3976        /*.nb03  =*/ nb03,
3977        /*.ne0   =*/ ne0,
3978        /*.ne1   =*/ ne1,
3979        /*.ne2   =*/ ne2,
3980        /*.ne3   =*/ ne3,
3981        /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
3982    };
3983
3984    if (npr > 1) {
3985        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3986    }
3987
3988    ggml_metal_encoder_set_pipeline(enc, pipeline);
3989    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
3990    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
3991    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
3992
3993    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3994
3995    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3996
3997    auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
3998
3999    int len = args.top_k;
4000
4001    while (len < args.ne0) {
4002        ggml_metal_op_concurrency_reset(ctx);
4003
4004        // merges per row
4005        const int nm = (args.ne0 + 2*len - 1) / (2*len);
4006
4007        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
4008
4009        ggml_metal_kargs_argsort_merge args_merge = {
4010            /*.ne00  =*/ ne00,
4011            /*.ne01  =*/ ne01,
4012            /*.ne02  =*/ ne02,
4013            /*.ne03  =*/ ne03,
4014            /*.nb00  =*/ nb00,
4015            /*.nb01  =*/ nb01,
4016            /*.nb02  =*/ nb02,
4017            /*.nb03  =*/ nb03,
4018            /*.ne0   =*/ args.ne0,
4019            /*.ne1   =*/ ne1,
4020            /*.ne2   =*/ ne2,
4021            /*.ne3   =*/ ne3,
4022            /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
4023            /*.len   =*/ len,
4024        };
4025
4026        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
4027        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
4028        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
4029        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
4030        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
4031
4032        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
4033
4034        std::swap(bid_dst, bid_tmp);
4035
4036        len <<= 1;
4037    }
4038
4039    return 1;
4040}
4041
4042int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
4043    ggml_tensor * op = ctx->node(idx);
4044
4045    ggml_metal_library_t lib = ctx->lib;
4046    ggml_metal_encoder_t enc = ctx->enc;
4047
4048    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4049    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4050    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
4051    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
4052
4053    ggml_metal_kargs_tri args = {
4054        /*.ne00  =*/ ne00,
4055        /*.ne01  =*/ ne01,
4056        /*.ne02  =*/ ne02,
4057        /*.ne03  =*/ ne03,
4058        /*.nb00  =*/ nb00,
4059        /*.nb01  =*/ nb01,
4060        /*.nb02  =*/ nb02,
4061        /*.nb03  =*/ nb03,
4062        /*.ne0   =*/ ne0,
4063        /*.ne1   =*/ ne1,
4064        /*.ne2   =*/ ne2,
4065        /*.ne3   =*/ ne3,
4066        /*.nb0   =*/ nb0,
4067        /*.nb1   =*/ nb1,
4068        /*.nb2   =*/ nb2,
4069        /*.nb3   =*/ nb3,
4070    };
4071
4072    auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
4073
4074    int nth = 32; // SIMD width
4075
4076    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4077        nth *= 2;
4078    }
4079
4080    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4081    nth = std::min(nth, ne00);
4082
4083    ggml_metal_encoder_set_pipeline(enc, pipeline);
4084    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
4085    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4086    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
4087
4088    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4089
4090    return 1;
4091}
4092
4093int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
4094    ggml_tensor * op = ctx->node(idx);
4095
4096    ggml_metal_library_t lib = ctx->lib;
4097    ggml_metal_encoder_t enc = ctx->enc;
4098
4099    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4100    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4101    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
4102    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
4103
4104    auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4105
4106    const int64_t np = ggml_nelements(op->src[0]);
4107    ggml_metal_kargs_opt_step_adamw args = {
4108        /*.np =*/ np,
4109    };
4110
4111    int ida = 0;
4112
4113    ggml_metal_encoder_set_pipeline(enc, pipeline);
4114    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
4115    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4116    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4117    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4118    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
4119    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
4120
4121    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4122    const int64_t n = (np + nth - 1) / nth;
4123
4124    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4125
4126    return 1;
4127}
4128
4129int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
4130    ggml_tensor * op = ctx->node(idx);
4131
4132    ggml_metal_library_t lib = ctx->lib;
4133    ggml_metal_encoder_t enc = ctx->enc;
4134
4135    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4136    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4137    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
4138    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
4139
4140    auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4141
4142    const int64_t np = ggml_nelements(op->src[0]);
4143    ggml_metal_kargs_opt_step_sgd args = {
4144        /*.np =*/ np,
4145    };
4146
4147    int ida = 0;
4148
4149    ggml_metal_encoder_set_pipeline(enc, pipeline);
4150    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
4151    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4152    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4153    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4154
4155    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4156    const int64_t n = (np + nth - 1) / nth;
4157
4158    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4159
4160    return 1;
4161}
4162
4163int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
4164    ggml_tensor * op = ctx->node(idx);
4165
4166    ggml_metal_library_t lib = ctx->lib;
4167    ggml_metal_encoder_t enc = ctx->enc;
4168
4169    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
4170    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4171    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4172
4173    {
4174        ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4175
4176        auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
4177
4178        ggml_metal_encoder_set_pipeline(enc, pipeline);
4179        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4180        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
4181
4182        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4183    }
4184
4185    ggml_metal_op_concurrency_reset(ctx);
4186
4187    {
4188        ggml_metal_kargs_count_equal args = {
4189            /*.ne00 =*/ ne00,
4190            /*.ne01 =*/ ne01,
4191            /*.ne02 =*/ ne02,
4192            /*.ne03 =*/ ne03,
4193            /*.nb00 =*/ nb00,
4194            /*.nb01 =*/ nb01,
4195            /*.nb02 =*/ nb02,
4196            /*.nb03 =*/ nb03,
4197            /*.nb10 =*/ nb10,
4198            /*.nb11 =*/ nb11,
4199            /*.nb12 =*/ nb12,
4200            /*.nb13 =*/ nb13,
4201        };
4202
4203        auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
4204
4205        const size_t smem = pipeline.smem;
4206
4207        const int nth = 32*pipeline.nsg;
4208
4209        GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4210
4211        ggml_metal_encoder_set_pipeline(enc, pipeline);
4212        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4213        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4214        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
4215        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
4216
4217        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4218        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4219    }
4220
4221    return 1;
4222}