diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp | 4222 |
1 files changed, 4222 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp new file mode 100644 index 0000000..7db95d1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -0,0 +1,4222 @@ +#include "ggml-metal-ops.h" + +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-metal-impl.h" +#include "ggml-metal-common.h" +#include "ggml-metal-device.h" + +#include <cassert> +#include <algorithm> +#include <limits> +#include <cmath> + +static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { + if (!t) { + return { nullptr, 0 }; + } + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context; + + return ggml_metal_buffer_get_id(ctx, t); +} + +struct ggml_metal_op { + ggml_metal_op( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + this->dev = dev; + this->lib = ggml_metal_device_get_library(dev); + this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency); + this->mem_ranges = ggml_mem_ranges_init(debug_graph); + this->idx_start = idx_start; + this->idx_end = idx_end; + this->use_fusion = use_fusion; + this->use_concurrency = use_concurrency; + this->use_capture = use_capture; + this->debug_graph = debug_graph; + this->debug_fusion = debug_fusion; + this->gf = gf; + + idxs.reserve(gf->n_nodes); + + // filter empty nodes + // TODO: this can be removed when the allocator starts filtering them earlier + // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830 + for (int i = idx_start; i < idx_end; i++) { + if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) { + idxs.push_back(i); + } + } + } + + ~ggml_metal_op() { + ggml_metal_encoder_end_encoding(this->enc); + ggml_metal_encoder_free(this->enc); + ggml_mem_ranges_free(this->mem_ranges); + } + + int n_nodes() const { + return idxs.size(); + } + + ggml_tensor * node(int i) const { + assert(i >= 0 && i < (int) idxs.size()); + return ggml_graph_node(gf, idxs[i]); + } + + bool can_fuse(int i0, const ggml_op * ops, int n_ops) const { + assert(use_fusion); + assert(i0 >= 0 && i0 < n_nodes()); + + if (i0 + n_ops > n_nodes()) { + return false; + } + + return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops); + } + + ggml_metal_device_t dev; + ggml_metal_library_t lib; + ggml_metal_encoder_t enc; + ggml_mem_ranges_t mem_ranges; + + bool use_fusion; + bool use_concurrency; + bool use_capture; + + int debug_graph; + int debug_fusion; + +private: + ggml_cgraph * gf; + + int idx_start; + int idx_end; + + // non-empty node indices + std::vector<int> idxs; +}; + +ggml_metal_op_t ggml_metal_op_init( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + ggml_metal_op_t res = new ggml_metal_op( + dev, + cmd_buf, + gf, + idx_start, + idx_end, + use_fusion, + use_concurrency, + use_capture, + debug_graph, + debug_fusion); + + return res; +} + +void ggml_metal_op_free(ggml_metal_op_t ctx) { + delete ctx; +} + +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) { + return ctx->n_nodes(); +} + +static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { + if (!ctx->mem_ranges) { + return true; + } + + ggml_metal_encoder_memory_barrier(ctx->enc); + + ggml_mem_ranges_reset(ctx->mem_ranges); + + return true; +} + +static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return false; + } + + return ggml_mem_ranges_check(ctx->mem_ranges, node); +} + +static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return true; + } + + return ggml_mem_ranges_add(ctx->mem_ranges, node); +} + +static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { + struct ggml_tensor * node = ctx->node(idx); + + //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); + + if (ggml_is_empty(node)) { + return 1; + } + + switch (node->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)"); + } + } return 1; + default: + { + } break; + } + + if (!ggml_metal_device_supports_op(ctx->dev, node)) { + GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node)); + GGML_ABORT("unsupported op"); + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return 1; + } + + int n_fuse = 1; + + // check if the current node can run concurrently with other nodes before it + // the condition is that: + // - the current node cannot write to any previous src or dst ranges + // - the current node cannot read from any previous dst ranges + // + // if the condition is not satisfied, we put a memory barrier and clear all ranges + // otherwise, we add the new ranges to the encoding context and process the node concurrently + // + { + const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node); + + if (!is_concurrent) { + ggml_metal_op_concurrency_reset(ctx); + } + + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : ""); + } + if (ctx->debug_graph > 1) { + GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); + GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); + GGML_TENSOR_LOCALS( int64_t, ne, node, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); + + if (node->src[0]) { + GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(node->src[0]), node->src[0]->name); + } + if (node->src[1]) { + GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(node->src[1]), node->src[1]->name); + } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } + if (node) { + GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + node->name); + } + } + } + + switch (node->op) { + case GGML_OP_CONCAT: + { + n_fuse = ggml_metal_op_concat(ctx, idx); + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + n_fuse = ggml_metal_op_bin(ctx, idx); + } break; + case GGML_OP_ADD_ID: + { + n_fuse = ggml_metal_op_add_id(ctx, idx); + } break; + case GGML_OP_REPEAT: + { + n_fuse = ggml_metal_op_repeat(ctx, idx); + } break; + case GGML_OP_ACC: + { + n_fuse = ggml_metal_op_acc(ctx, idx); + } break; + case GGML_OP_SCALE: + case GGML_OP_FILL: + case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + case GGML_OP_UNARY: + { + n_fuse = ggml_metal_op_unary(ctx, idx); + } break; + case GGML_OP_GLU: + { + n_fuse = ggml_metal_op_glu(ctx, idx); + } break; + case GGML_OP_SUM: + { + n_fuse = ggml_metal_op_sum(ctx, idx); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + { + n_fuse = ggml_metal_op_sum_rows(ctx, idx); + } break; + case GGML_OP_CUMSUM: + { + n_fuse = ggml_metal_op_cumsum(ctx, idx); + } break; + case GGML_OP_SOFT_MAX: + { + n_fuse = ggml_metal_op_soft_max(ctx, idx); + } break; + case GGML_OP_SSM_CONV: + { + n_fuse = ggml_metal_op_ssm_conv(ctx, idx); + } break; + case GGML_OP_SSM_SCAN: + { + n_fuse = ggml_metal_op_ssm_scan(ctx, idx); + } break; + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + { + n_fuse = ggml_metal_op_rwkv(ctx, idx); + } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; + case GGML_OP_MUL_MAT: + { + n_fuse = ggml_metal_op_mul_mat(ctx, idx); + } break; + case GGML_OP_MUL_MAT_ID: + { + n_fuse = ggml_metal_op_mul_mat_id(ctx, idx); + } break; + case GGML_OP_GET_ROWS: + { + n_fuse = ggml_metal_op_get_rows(ctx, idx); + } break; + case GGML_OP_SET_ROWS: + { + n_fuse = ggml_metal_op_set_rows(ctx, idx); + } break; + case GGML_OP_DIAG: + { + n_fuse = ggml_metal_op_diag(ctx, idx); + } break; + case GGML_OP_L2_NORM: + { + n_fuse = ggml_metal_op_l2_norm(ctx, idx); + } break; + case GGML_OP_GROUP_NORM: + { + n_fuse = ggml_metal_op_group_norm(ctx, idx); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + { + n_fuse = ggml_metal_op_norm(ctx, idx); + } break; + case GGML_OP_ROPE: + { + n_fuse = ggml_metal_op_rope(ctx, idx); + } break; + case GGML_OP_IM2COL: + { + n_fuse = ggml_metal_op_im2col(ctx, idx); + } break; + case GGML_OP_CONV_2D: + { + n_fuse = ggml_metal_op_conv_2d(ctx, idx); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); + } break; + case GGML_OP_UPSCALE: + { + n_fuse = ggml_metal_op_upscale(ctx, idx); + } break; + case GGML_OP_PAD: + { + n_fuse = ggml_metal_op_pad(ctx, idx); + } break; + case GGML_OP_PAD_REFLECT_1D: + { + n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); + } break; + case GGML_OP_ARANGE: + { + n_fuse = ggml_metal_op_arange(ctx, idx); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + n_fuse = ggml_metal_op_timestep_embedding(ctx, idx); + } break; + case GGML_OP_ARGSORT: + { + n_fuse = ggml_metal_op_argsort(ctx, idx); + } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; + case GGML_OP_TRI: + { + n_fuse = ggml_metal_op_tri(ctx, idx); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + n_fuse = ggml_metal_op_cpy(ctx, idx); + } break; + case GGML_OP_POOL_1D: + { + n_fuse = ggml_metal_op_pool_1d(ctx, idx); + } break; + case GGML_OP_POOL_2D: + { + n_fuse = ggml_metal_op_pool_2d(ctx, idx); + } break; + case GGML_OP_ARGMAX: + { + n_fuse = ggml_metal_op_argmax(ctx, idx); + } break; + case GGML_OP_OPT_STEP_ADAMW: + { + n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); + } break; + case GGML_OP_OPT_STEP_SGD: + { + n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); + } break; + case GGML_OP_COUNT_EQUAL: + { + n_fuse = ggml_metal_op_count_equal(ctx, idx); + } break; + default: + { + GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); + GGML_ABORT("fatal error"); + } + } + + if (ctx->debug_graph > 0) { + if (n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse); + } + } + + // update the mem ranges in the encoding context + for (int i = 0; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + } + } + + return n_fuse; +} + +int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx))); + } + + int res = ggml_metal_op_encode_impl(ctx, idx); + if (idx + res > ctx->n_nodes()) { + GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", + "https://github.com/ggml-org/llama.cpp/pull/14849"); + } + + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_pop(ctx->enc); + } + + return res; +} + +int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t dim = ((const int32_t *) op->op_params)[0]; + + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + /*.o1 =*/ { 0 }, + }; + + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_unary args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.slope =*/ 0.0, + /*.scale =*/ 0.0, + /*.bias =*/ 0.0, + /*.val =*/ 0.0, + /*.min =*/ 0.0, + /*.max =*/ 0.0, + }; + + if (op->op == GGML_OP_LEAKY_RELU) { + args.slope = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_SCALE) { + args.scale = ggml_get_op_params_f32(op, 0); + args.bias = ggml_get_op_params_f32(op, 1); + } + + if (op->op == GGML_OP_FILL) { + args.val = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_CLAMP) { + args.min = ggml_get_op_params_f32(op, 0); + args.max = ggml_get_op_params_f32(op, 1); + } + + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nth = MIN(args.ne00, nth_max); + + const int nk0 = (args.ne00 + nth - 1)/nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); + } + + return 1; +} + +int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + if (op->src[1]) { + GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); + } + + auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + + const int32_t swp = ggml_get_op_params_i32(op, 1); + const float alpha = ggml_get_op_params_f32(op, 2); + const float limit = ggml_get_op_params_f32(op, 3); + + const int32_t i00 = swp ? ne0 : 0; + const int32_t i10 = swp ? 0 : ne0; + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.ne10 =*/ op->src[1] ? ne10 : ne00, + /*.nb11 =*/ op->src[1] ? nb11 : nb01, + /*.ne0 =*/ ne0, + /*.nb1 =*/ nb1, + /*.i00 =*/ op->src[1] ? 0 : i00, + /*.i10 =*/ op->src[1] ? 0 : i10, + /*.alpha=*/ alpha, + /*.limit=*/ limit + }; + + const int64_t nrows = ggml_nrows(op->src[0]); + + const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); + + ggml_metal_kargs_sum args = { + /*.np =*/ n, + }; + + auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + + int nth = 32; // SIMD width + + while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, (int) n); + + const int nsg = (nth + 31) / 32; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); + + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { + nth *= 2; + } + + GGML_ASSERT(ne00 <= nth*nth); + + const int64_t net0 = (ne00 + nth - 1) / nth; + const int64_t net1 = ne01; + const int64_t net2 = ne02; + const int64_t net3 = ne03; + + const uint64_t nbt0 = sizeof(float); + const uint64_t nbt1 = net0*nbt0; + const uint64_t nbt2 = net1*nbt1; + const uint64_t nbt3 = net2*nbt2; + + const size_t smem = GGML_PAD(32*sizeof(float), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += ggml_nbytes(op); + + { + ggml_metal_kargs_cumsum_blk args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.net0 =*/ net0, + /*.net1 =*/ net1, + /*.net2 =*/ net2, + /*.net3 =*/ net3, + /*.nbt0 =*/ nbt0, + /*.nbt1 =*/ nbt1, + /*.nbt2 =*/ nbt2, + /*.nbt3 =*/ nbt3, + /*.outb =*/ ne00 > nth, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_blk); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 2); + ggml_metal_encoder_set_buffer (enc, bid_dst, 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1); + } + + if (ne00 > nth) { + ggml_metal_op_concurrency_reset(ctx); + + { + ggml_metal_kargs_cumsum_blk args = { + /*.ne00 =*/ net0, + /*.ne01 =*/ net1, + /*.ne02 =*/ net2, + /*.ne03 =*/ net3, + /*.nb00 =*/ nbt0, + /*.nb01 =*/ nbt1, + /*.nb02 =*/ nbt2, + /*.nb03 =*/ nbt3, + /*.net0 =*/ net0, + /*.net1 =*/ net1, + /*.net2 =*/ net2, + /*.net3 =*/ net3, + /*.nbt0 =*/ nbt0, + /*.nbt1 =*/ nbt1, + /*.nbt2 =*/ nbt2, + /*.nbt3 =*/ nbt3, + /*.outb =*/ false, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_blk); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1); + } + + ggml_metal_op_concurrency_reset(ctx); + + { + auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); + + ggml_metal_kargs_cumsum_add args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.net0 =*/ net0, + /*.net1 =*/ net1, + /*.net2 =*/ net2, + /*.net3 =*/ net3, + /*.nbt0 =*/ nbt0, + /*.nbt1 =*/ nbt1, + /*.nbt2 =*/ nbt2, + /*.nbt3 =*/ nbt3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_add); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1); + } + } + + return 1; +} + +int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + + ggml_metal_kargs_get_rows args = { + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + + const int32_t nk0 = ne0/ggml_blck_size(op->type); + + int nth = 32; // SIMD width + + while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + int nrptg = 1; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_set_rows args = { + /*.nk0 =*/ nk0, + /*.ne01 =*/ ne01, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_diag args = { + /*.ne00 =*/ne00, + /*.ne01 =*/ne01, + /*.ne02 =*/ne02, + /*.ne03 =*/ne03, + /*.nb00 =*/nb00, + /*.nb01 =*/nb01, + /*.nb02 =*/nb02, + /*.nb03 =*/nb03, + /*.ne0 =*/ne0, + /*.ne1 =*/ne1, + /*.ne2 =*/ne2, + /*.ne3 =*/ne3, + /*.nb0 =*/nb0, + /*.nb1 =*/nb1, + /*.nb2 =*/nb2, + /*.nb3 =*/nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1); + + return 1; +} + +int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + + auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + + int nth = 32; // SIMD width + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + if (op->src[2]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + const bool use_batched = (ne1 > 1); + + if (use_batched) { + // Determine the smallest power of 2 that's >= ne1, but <= 256 + int BATCH_SIZE; + if (ne1 > 128) BATCH_SIZE = 256; + else if (ne1 > 64 ) BATCH_SIZE = 128; + else if (ne1 > 32 ) BATCH_SIZE = 64; + else if (ne1 > 16 ) BATCH_SIZE = 32; + else if (ne1 > 8 ) BATCH_SIZE = 16; + else if (ne1 > 4 ) BATCH_SIZE = 8; + else BATCH_SIZE = 2; + + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } + + return 1; +} + +int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne); + GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb); + GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne); + GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb); + GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne); + GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const ggml_tensor * src3 = op->src[3]; + const ggml_tensor * src4 = op->src[4]; + const ggml_tensor * src5 = op->src[5]; + const ggml_tensor * src6 = op->src[6]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + GGML_ASSERT(src6); + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; + + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, + /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, + /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, + /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, + }; + + auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); + + return 1; +} + +int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; + const int64_t T = op->src[0]->ne[2]; + const int64_t C = op->ne[0]; + const int64_t H = op->src[0]->ne[1]; + + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + + ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + + return 1; +} + +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); + + return 1; +} + +int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); + + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); + } + + int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min<int>(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t s0 = opts[2]; + const int32_t p0 = opts[3]; + + const int64_t IW = op->src[0]->ne[0]; + const int64_t OW = op->ne[0]; + + const int64_t np = ggml_nelements(op); + + ggml_metal_kargs_pool_1d args_pool_1d = { + /* .k0 = */ k0, + /* .s0 = */ s0, + /* .p0 = */ p0, + /* .IW = */ IW, + /* .OW = */ OW, + /* .np = */ np + }; + + auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + + +int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = op->src[0]->ne[1]; + const int64_t IW = op->src[0]->ne[0]; + + const int64_t N = op->ne[3]; + const int64_t OC = op->ne[2]; + const int64_t OH = op->ne[1]; + const int64_t OW = op->ne[0]; + + const int64_t np = N * OC * OH * OW; + + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .np = */ np + }; + + auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const int16_t r2 = ne12/ne02; + const int16_t r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + const int ne11_mm_min = 8; + + // first try to use small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] + if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) && + ( + ( + ( + op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q8_0 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 8) + ) + ) + ) { + // TODO: determine the optimal parameters based on grid utilization + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + + // num threads along row per simdgroup + int16_t nxpsg = 0; + if (ne00 % 256 == 0 && ne11 < 3) { + nxpsg = 16; + } else if (ne00 % 128 == 0) { + nxpsg = 8; + } else { + nxpsg = 4; + } + + const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int16_t r1ptg = 4; // num src1 rows per threadgroup + + // note: not sure how optimal are those across all different hardware. there might be someting cleverer + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + default: + GGML_ABORT("unsupported ne11"); + }; + + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1); + } else if ( + !ggml_is_transposed(op->src[0]) && + !ggml_is_transposed(op->src[1]) && + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} + + auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + const size_t smem = pipeline.smem; + + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } + } + + return 1; +} + +size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + + return ggml_type_size(GGML_TYPE_I32)*ne02; +} + +size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + const int64_t ne21 = op->src[2]->ne[1]; // n_token + + return ggml_type_size(GGML_TYPE_I32)*ne02*ne21; +} + +int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // src2 = ids + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(op->src[0])); + GGML_ASSERT(!ggml_is_transposed(op->src[1])); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const uint32_t r2 = 1; + const uint32_t r3 = 1; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; + + if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} + + // extra buffers for intermediate id mapping + ggml_metal_buffer_id bid_tpe = bid_dst; + bid_tpe.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_ids = bid_tpe; + bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op); + + { + ggml_metal_kargs_mul_mm_id_map0 args = { + ne02, + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + ne21, // n_tokens + ne20, // n_expert_used + nb21, + }; + + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + + const size_t smem = pipeline.smem; + + GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src2, 1); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 2); + ggml_metal_encoder_set_buffer (enc, bid_ids, 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1); + } + + // this barrier is always needed because the next kernel has to wait for the id maps to be computed + ggml_metal_op_concurrency_reset(ctx); + + { + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, // n_expert_used (bcast) + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, // n_expert_used + /*.ne21 =*/ ne21, // n_tokens + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 3); + ggml_metal_encoder_set_buffer (enc, bid_ids, 4); + ggml_metal_encoder_set_buffer (enc, bid_dst, 5); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); + } + } else { + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + const size_t smem = pipeline.smem; + + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, + }; + + if (ggml_is_quantized(op->src[0]->type)) { + GGML_ASSERT(ne00 >= nsg*nr0); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, bid_src0, 1); + ggml_metal_encoder_set_buffer(enc, bid_src1, 2); + ggml_metal_encoder_set_buffer(enc, bid_dst, 3); + ggml_metal_encoder_set_buffer(enc, bid_src2, 4); + + const int64_t _ne1 = 1; + const int64_t ne123 = ne20*ne21; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } + } + + return 1; +} + +int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_kargs_add_id args = { + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb11 =*/ nb11, + /*.nb21 =*/ nb21, + }; + + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1); + + return 1; +} + +bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t ne00 = op->src[0]->ne[0]; // head size + const int64_t ne01 = op->src[0]->ne[1]; // batch size + + // use vec kernel if the batch size is small and if the head size is supported + return (ne01 < 20) && (ne00 % 32 == 0); +} + +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + // note: the non-vec kernel requires more extra memory, so always reserve for it + GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG); + + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (false) { + // note: always reserve the padding space to avoid graph reallocations + //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + const bool has_kvpad = true; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + const bool has_kvpad = true; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + // note: always reserve the blk buffer to avoid graph reallocations + //if (is_vec) { + // return res; + //} + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; + const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + // note: always reserve the temp buffer to avoid graph reallocations + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (true) { + const int64_t nwg = 32; + const int64_t ne01_max = std::min(ne01, 32); + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; +} + +int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS( int32_t, nb, op, nb); + + GGML_ASSERT(ne00 % 4 == 0); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == op->src[2]->type); + + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); + + GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); + + float scale; + float max_bias; + float logit_softcap; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const bool has_mask = op->src[3] != NULL; + const bool has_sinks = op->src[4] != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(ne01 < 65536); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); + + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { + // half8x8 kernel + const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } + + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; + + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) + + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > props_dev->max_theadgroup_memory_size) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} + + // simdgroups per threadgroup (a.k.a. warps) + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = ne00 >= 512 ? 8 : 4; + + const size_t smem = FATTN_SMEM(nsg); + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1); +#undef FATTN_SMEM + } else { + // half4x4 kernel + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nhptg = 1; // heads per threadgroup + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + + // note: for simplicity assume the K is larger or equal than V + GGML_ASSERT(ne10 >= ne20); + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne20*(nsg) + // each simdgroup has a full f32 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsg = 1; + + // workgroups + // each workgroup handles nsg*nkpsg cache values + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } + + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + + GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + + // using 1 workgroup -> write the result directly into dst + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); + } else { + // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); + GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); + + // write the results from each workgroup into a temp buffer + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); + + // sync the 2 kernels + ggml_metal_op_concurrency_reset(ctx); + + // reduce the results from the workgroups + { + const int32_t nrows = ne1*ne2*ne3; + + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { + nrows, + }; + + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1); + } + } +#undef FATTN_SMEM + } + + return 1; +} + +int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ 0, + /*.o1 =*/ { bid_src1.offs }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + // c[0] = add(a, b[0]) + // c[1] = add(c[0], b[1]) + // c[2] = add(c[1], b[2]) + // ... + if (use_fusion) { + fops[0] = GGML_OP_ADD; + fops[1] = GGML_OP_ADD; + fops[2] = GGML_OP_ADD; + fops[3] = GGML_OP_ADD; + fops[4] = GGML_OP_ADD; + fops[5] = GGML_OP_ADD; + fops[6] = GGML_OP_ADD; + fops[7] = GGML_OP_ADD; + + // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops + // across splits. idx_end indicates the last node in the current split + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { + break; + } + + // b[0] === b[1] === ... + if (!ggml_are_same_layout(f0->src[1], f1->src[1])) { + break; + } + + // only fuse ops if src1 is in the same Metal buffer + ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]); + if (bid_fuse.metal != bid_src1.metal) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + args.o1[n_fuse + 1] = bid_fuse.offs; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + } + } + + // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer + bid_src1.offs = 0; + + struct ggml_metal_pipeline_with_params pipeline; + + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_dst, 3); + + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + } + + return n_fuse; +} + +int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + }; + + auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t ngrp = ((const int32_t *) op->op_params)[0]; + + float eps; + memcpy(&eps, op->op_params + 1, sizeof(float)); + + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ngrp =*/ ngrp, + /*.eps =*/ eps, + }; + + auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + + int nth = 32; // SIMD width + //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + // nth *= 2; + //} + + //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + //nth = std::min(nth, ne00/4); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; + + // d[0] = norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (use_fusion) { + fops[0] = op->op; + fops[1] = GGML_OP_MUL; + fops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { + break; + } + + if (f1->src[1]->ne[0] != op->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(f1->src[1])) { + break; + } + + if (f1->type != GGML_TYPE_F32) { + break; + } + + //ctx->fuse_cnt[f1->op]++; + + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]); + + args.nef1[n_fuse + 1] = f1->src[1]->ne[1]; + args.nef2[n_fuse + 1] = f1->src[1]->ne[2]; + args.nef3[n_fuse + 1] = f1->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = f1->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = f1->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = f1->src[1]->nb[3]; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op)); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op)); + } + } + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + + int nth = 32; // SIMD width + + while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, args.ne00_t); + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); + ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); + ggml_metal_encoder_set_buffer (enc, bid_dst, 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return n_fuse; +} + +int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // make sure we have one or more position id(ne10) per token(ne02) + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + const int nth = std::min(1024, ne00); + + const int n_past = ((const int32_t *) op->op_params)[0]; + const int n_dims = ((const int32_t *) op->op_params)[1]; + //const int mode = ((const int32_t *) op->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) op->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float)); + + // mrope + const int sect_0 = ((const int32_t *) op->op_params)[11]; + const int sect_1 = ((const int32_t *) op->op_params)[12]; + const int sect_2 = ((const int32_t *) op->op_params)[13]; + const int sect_3 = ((const int32_t *) op->op_params)[14]; + + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, + /* src2 =*/ op->src[2] != nullptr, + }; + + auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + if (op->src[2]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t p0 = ((const int32_t *)(op->op_params))[2]; + const int32_t p1 = ((const int32_t *)(op->op_params))[3]; + const int32_t d0 = ((const int32_t *)(op->op_params))[4]; + const int32_t d1 = ((const int32_t *)(op->op_params))[5]; + + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + + const int32_t N = op->src[1]->ne[is_2D ? 3 : 2]; + const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? op->src[1]->ne[1] : 1; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = is_2D ? op->src[0]->ne[1] : 1; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OH = is_2D ? op->ne[2] : 1; + const int32_t OW = op->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; + + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + + auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + + GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + + return 1; +} + +int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *) op->op_params)[0]; + const int32_t s1 = ((const int32_t *) op->op_params)[1]; + const int32_t p0 = ((const int32_t *) op->op_params)[2]; + const int32_t p1 = ((const int32_t *) op->op_params)[3]; + const int32_t d0 = ((const int32_t *) op->op_params)[4]; + const int32_t d1 = ((const int32_t *) op->op_params)[5]; + + ggml_metal_kargs_conv_2d args = { + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.IW =*/ ne10, + /*.IH =*/ ne11, + /*.KW =*/ ne00, + /*.KH =*/ ne01, + /*.IC =*/ ne02, + /*.OC =*/ ne03, + /*.OW =*/ ne0, + /*.OH =*/ ne1, + /*.N =*/ ne3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + }; + + auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); + + int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); + nth = std::min(nth, 256); + nth = std::max(nth, 1); + + const uint64_t n_out = ggml_nelements(op); + + uint64_t tg = (n_out + nth - 1)/nth; + tg = std::max<uint64_t>(tg, 1); + tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max()); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[1]; + const int32_t IL = op->src[1]->ne[0]; + + const int32_t K = op->src[0]->ne[0]; + + const int32_t OL = op->ne[0]; + const int32_t OC = op->ne[1]; + + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[2]; + const int32_t IH = op->src[1]->ne[1]; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = op->src[0]->ne[1]; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OW = op->ne[0]; + const int32_t OH = op->ne[1]; + const int32_t OC = op->ne[2]; + + ggml_metal_kargs_conv_transpose_2d args = { + /*.IC =*/ IC, + /*.IH =*/ IH, + /*.IW =*/ IW, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.OC =*/ OC, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + // Metal requires buffer size to be multiple of 16 bytes + const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1); + + return 1; +} + +int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const float sf0 = (float)ne0/op->src[0]->ne[0]; + const float sf1 = (float)ne1/op->src[0]->ne[1]; + const float sf2 = (float)ne2/op->src[0]->ne[2]; + const float sf3 = (float)ne3/op->src[0]->ne[3]; + + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + + auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ ((const int32_t *)(op->op_params))[0], + /*.p1 =*/ ((const int32_t *)(op->op_params))[1] + }; + + auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + float start; + float step; + + memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float)); + + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + + const int nth = std::min(1024, ne0); + + auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int dim = op->op_params[0]; + const int max_period = op->op_params[1]; + + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period, + }; + + auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + + const int nth = std::max(1, std::min(1024, dim/2)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_argmax args = { + /*.ne00 = */ ne00, + /*.nb01 = */ nb01, + }; + + auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + + const int64_t nrows = ggml_nrows(op->src[0]); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + const size_t smem = pipeline.smem; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + const int npr = (ne00 + nth - 1)/nth; + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += ggml_nbytes(op); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); + + int len = nth; + + while (len < ne00) { + ggml_metal_op_concurrency_reset(ctx); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, + }; + + // merges per row + const int nm = (ne00 + 2*len - 1) / (2*len); + + const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)); + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + +int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_adamw args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_sgd args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + + { + ggml_metal_kargs_memset args = { /*.val =*/ 0 }; + + auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + } + + ggml_metal_op_concurrency_reset(ctx); + + { + ggml_metal_kargs_count_equal args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + }; + + auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op); + + const size_t smem = pipeline.smem; + + const int nth = 32*pipeline.nsg; + + GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + } + + return 1; +} |
