summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp')
-rw-r--r--llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp4222
1 files changed, 4222 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp
new file mode 100644
index 0000000..7db95d1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -0,0 +1,4222 @@
+#include "ggml-metal-ops.h"
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-metal-impl.h"
+#include "ggml-metal-common.h"
+#include "ggml-metal-device.h"
+
+#include <cassert>
+#include <algorithm>
+#include <limits>
+#include <cmath>
+
+static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
+ if (!t) {
+ return { nullptr, 0 };
+ }
+
+ ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+ ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;
+
+ return ggml_metal_buffer_get_id(ctx, t);
+}
+
+struct ggml_metal_op {
+ ggml_metal_op(
+ ggml_metal_device_t dev,
+ ggml_metal_cmd_buf_t cmd_buf,
+ ggml_cgraph * gf,
+ int idx_start,
+ int idx_end,
+ bool use_fusion,
+ bool use_concurrency,
+ bool use_capture,
+ int debug_graph,
+ int debug_fusion) {
+ this->dev = dev;
+ this->lib = ggml_metal_device_get_library(dev);
+ this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
+ this->mem_ranges = ggml_mem_ranges_init(debug_graph);
+ this->idx_start = idx_start;
+ this->idx_end = idx_end;
+ this->use_fusion = use_fusion;
+ this->use_concurrency = use_concurrency;
+ this->use_capture = use_capture;
+ this->debug_graph = debug_graph;
+ this->debug_fusion = debug_fusion;
+ this->gf = gf;
+
+ idxs.reserve(gf->n_nodes);
+
+ // filter empty nodes
+ // TODO: this can be removed when the allocator starts filtering them earlier
+ // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
+ for (int i = idx_start; i < idx_end; i++) {
+ if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
+ idxs.push_back(i);
+ }
+ }
+ }
+
+ ~ggml_metal_op() {
+ ggml_metal_encoder_end_encoding(this->enc);
+ ggml_metal_encoder_free(this->enc);
+ ggml_mem_ranges_free(this->mem_ranges);
+ }
+
+ int n_nodes() const {
+ return idxs.size();
+ }
+
+ ggml_tensor * node(int i) const {
+ assert(i >= 0 && i < (int) idxs.size());
+ return ggml_graph_node(gf, idxs[i]);
+ }
+
+ bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
+ assert(use_fusion);
+ assert(i0 >= 0 && i0 < n_nodes());
+
+ if (i0 + n_ops > n_nodes()) {
+ return false;
+ }
+
+ return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
+ }
+
+ ggml_metal_device_t dev;
+ ggml_metal_library_t lib;
+ ggml_metal_encoder_t enc;
+ ggml_mem_ranges_t mem_ranges;
+
+ bool use_fusion;
+ bool use_concurrency;
+ bool use_capture;
+
+ int debug_graph;
+ int debug_fusion;
+
+private:
+ ggml_cgraph * gf;
+
+ int idx_start;
+ int idx_end;
+
+ // non-empty node indices
+ std::vector<int> idxs;
+};
+
+ggml_metal_op_t ggml_metal_op_init(
+ ggml_metal_device_t dev,
+ ggml_metal_cmd_buf_t cmd_buf,
+ ggml_cgraph * gf,
+ int idx_start,
+ int idx_end,
+ bool use_fusion,
+ bool use_concurrency,
+ bool use_capture,
+ int debug_graph,
+ int debug_fusion) {
+ ggml_metal_op_t res = new ggml_metal_op(
+ dev,
+ cmd_buf,
+ gf,
+ idx_start,
+ idx_end,
+ use_fusion,
+ use_concurrency,
+ use_capture,
+ debug_graph,
+ debug_fusion);
+
+ return res;
+}
+
+void ggml_metal_op_free(ggml_metal_op_t ctx) {
+ delete ctx;
+}
+
+int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
+ return ctx->n_nodes();
+}
+
+static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
+ if (!ctx->mem_ranges) {
+ return true;
+ }
+
+ ggml_metal_encoder_memory_barrier(ctx->enc);
+
+ ggml_mem_ranges_reset(ctx->mem_ranges);
+
+ return true;
+}
+
+static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {
+ if (!ctx->mem_ranges) {
+ return false;
+ }
+
+ return ggml_mem_ranges_check(ctx->mem_ranges, node);
+}
+
+static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {
+ if (!ctx->mem_ranges) {
+ return true;
+ }
+
+ return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
+static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
+ struct ggml_tensor * node = ctx->node(idx);
+
+ //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
+
+ if (ggml_is_empty(node)) {
+ return 1;
+ }
+
+ switch (node->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop -> next node
+ if (ctx->debug_graph > 0) {
+ GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
+ }
+ } return 1;
+ default:
+ {
+ } break;
+ }
+
+ if (!ggml_metal_device_supports_op(ctx->dev, node)) {
+ GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node));
+ GGML_ABORT("unsupported op");
+ }
+
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+ return 1;
+ }
+
+ int n_fuse = 1;
+
+ // check if the current node can run concurrently with other nodes before it
+ // the condition is that:
+ // - the current node cannot write to any previous src or dst ranges
+ // - the current node cannot read from any previous dst ranges
+ //
+ // if the condition is not satisfied, we put a memory barrier and clear all ranges
+ // otherwise, we add the new ranges to the encoding context and process the node concurrently
+ //
+ {
+ const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);
+
+ if (!is_concurrent) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ if (ctx->debug_graph > 0) {
+ GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
+ }
+ if (ctx->debug_graph > 1) {
+ GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
+ GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+ GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+ GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
+ GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
+
+ if (node->src[0]) {
+ GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+ ggml_is_contiguous(node->src[0]), node->src[0]->name);
+ }
+ if (node->src[1]) {
+ GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+ ggml_is_contiguous(node->src[1]), node->src[1]->name);
+ }
+ if (node->src[2]) {
+ GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+ ggml_is_contiguous(node->src[2]), node->src[2]->name);
+ }
+ if (node->src[3]) {
+ GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+ ggml_is_contiguous(node->src[3]), node->src[3]->name);
+ }
+ if (node) {
+ GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+ node->name);
+ }
+ }
+ }
+
+ switch (node->op) {
+ case GGML_OP_CONCAT:
+ {
+ n_fuse = ggml_metal_op_concat(ctx, idx);
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ {
+ n_fuse = ggml_metal_op_bin(ctx, idx);
+ } break;
+ case GGML_OP_ADD_ID:
+ {
+ n_fuse = ggml_metal_op_add_id(ctx, idx);
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ n_fuse = ggml_metal_op_repeat(ctx, idx);
+ } break;
+ case GGML_OP_ACC:
+ {
+ n_fuse = ggml_metal_op_acc(ctx, idx);
+ } break;
+ case GGML_OP_SCALE:
+ case GGML_OP_FILL:
+ case GGML_OP_CLAMP:
+ case GGML_OP_LEAKY_RELU:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
+ case GGML_OP_LOG:
+ case GGML_OP_UNARY:
+ {
+ n_fuse = ggml_metal_op_unary(ctx, idx);
+ } break;
+ case GGML_OP_GLU:
+ {
+ n_fuse = ggml_metal_op_glu(ctx, idx);
+ } break;
+ case GGML_OP_SUM:
+ {
+ n_fuse = ggml_metal_op_sum(ctx, idx);
+ } break;
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ {
+ n_fuse = ggml_metal_op_sum_rows(ctx, idx);
+ } break;
+ case GGML_OP_CUMSUM:
+ {
+ n_fuse = ggml_metal_op_cumsum(ctx, idx);
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ n_fuse = ggml_metal_op_soft_max(ctx, idx);
+ } break;
+ case GGML_OP_SSM_CONV:
+ {
+ n_fuse = ggml_metal_op_ssm_conv(ctx, idx);
+ } break;
+ case GGML_OP_SSM_SCAN:
+ {
+ n_fuse = ggml_metal_op_ssm_scan(ctx, idx);
+ } break;
+ case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
+ {
+ n_fuse = ggml_metal_op_rwkv(ctx, idx);
+ } break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ n_fuse = ggml_metal_op_mul_mat(ctx, idx);
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ n_fuse = ggml_metal_op_get_rows(ctx, idx);
+ } break;
+ case GGML_OP_SET_ROWS:
+ {
+ n_fuse = ggml_metal_op_set_rows(ctx, idx);
+ } break;
+ case GGML_OP_DIAG:
+ {
+ n_fuse = ggml_metal_op_diag(ctx, idx);
+ } break;
+ case GGML_OP_L2_NORM:
+ {
+ n_fuse = ggml_metal_op_l2_norm(ctx, idx);
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ n_fuse = ggml_metal_op_group_norm(ctx, idx);
+ } break;
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ {
+ n_fuse = ggml_metal_op_norm(ctx, idx);
+ } break;
+ case GGML_OP_ROPE:
+ {
+ n_fuse = ggml_metal_op_rope(ctx, idx);
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ n_fuse = ggml_metal_op_im2col(ctx, idx);
+ } break;
+ case GGML_OP_CONV_2D:
+ {
+ n_fuse = ggml_metal_op_conv_2d(ctx, idx);
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ n_fuse = ggml_metal_op_upscale(ctx, idx);
+ } break;
+ case GGML_OP_PAD:
+ {
+ n_fuse = ggml_metal_op_pad(ctx, idx);
+ } break;
+ case GGML_OP_PAD_REFLECT_1D:
+ {
+ n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
+ } break;
+ case GGML_OP_ARANGE:
+ {
+ n_fuse = ggml_metal_op_arange(ctx, idx);
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ n_fuse = ggml_metal_op_argsort(ctx, idx);
+ } break;
+ case GGML_OP_TOP_K:
+ {
+ n_fuse = ggml_metal_op_top_k(ctx, idx);
+ } break;
+ case GGML_OP_TRI:
+ {
+ n_fuse = ggml_metal_op_tri(ctx, idx);
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ n_fuse = ggml_metal_op_cpy(ctx, idx);
+ } break;
+ case GGML_OP_POOL_1D:
+ {
+ n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+ } break;
+ case GGML_OP_POOL_2D:
+ {
+ n_fuse = ggml_metal_op_pool_2d(ctx, idx);
+ } break;
+ case GGML_OP_ARGMAX:
+ {
+ n_fuse = ggml_metal_op_argmax(ctx, idx);
+ } break;
+ case GGML_OP_OPT_STEP_ADAMW:
+ {
+ n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
+ } break;
+ case GGML_OP_OPT_STEP_SGD:
+ {
+ n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
+ } break;
+ case GGML_OP_COUNT_EQUAL:
+ {
+ n_fuse = ggml_metal_op_count_equal(ctx, idx);
+ } break;
+ default:
+ {
+ GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
+ GGML_ABORT("fatal error");
+ }
+ }
+
+ if (ctx->debug_graph > 0) {
+ if (n_fuse > 1) {
+ GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
+ }
+ }
+
+ // update the mem ranges in the encoding context
+ for (int i = 0; i < n_fuse; ++i) {
+ if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+ }
+
+ return n_fuse;
+}
+
+int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
+ if (ctx->use_capture) {
+ ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
+ }
+
+ int res = ggml_metal_op_encode_impl(ctx, idx);
+ if (idx + res > ctx->n_nodes()) {
+ GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
+ }
+
+ if (ctx->use_capture) {
+ ggml_metal_encoder_debug_group_pop(ctx->enc);
+ }
+
+ return res;
+}
+
+int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t dim = ((const int32_t *) op->op_params)[0];
+
+ ggml_metal_kargs_concat args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.dim =*/ dim,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const int nth = std::min(1024, ne0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
+
+ ggml_metal_kargs_repeat args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+ const size_t offs = ((const int32_t *) op->op_params)[3];
+
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+ if (!inplace) {
+ // run a separete kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+ ggml_metal_kargs_cpy args = {
+ /*.nk0 =*/ ne00,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ ggml_metal_kargs_bin args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ pnb1,
+ /*.nb02 =*/ pnb2,
+ /*.nb03 =*/ pnb3,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ pnb1,
+ /*.nb2 =*/ pnb2,
+ /*.nb3 =*/ pnb3,
+ /*.offs =*/ offs,
+ /*.o1 =*/ { 0 },
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_kargs_unary args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.slope =*/ 0.0,
+ /*.scale =*/ 0.0,
+ /*.bias =*/ 0.0,
+ /*.val =*/ 0.0,
+ /*.min =*/ 0.0,
+ /*.max =*/ 0.0,
+ };
+
+ if (op->op == GGML_OP_LEAKY_RELU) {
+ args.slope = ggml_get_op_params_f32(op, 0);
+ }
+
+ if (op->op == GGML_OP_SCALE) {
+ args.scale = ggml_get_op_params_f32(op, 0);
+ args.bias = ggml_get_op_params_f32(op, 1);
+ }
+
+ if (op->op == GGML_OP_FILL) {
+ args.val = ggml_get_op_params_f32(op, 0);
+ }
+
+ if (op->op == GGML_OP_CLAMP) {
+ args.min = ggml_get_op_params_f32(op, 0);
+ args.max = ggml_get_op_params_f32(op, 1);
+ }
+
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne0 = ne0/4;
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ if (pipeline.cnt) {
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ } else {
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const int nth = MIN(args.ne00, nth_max);
+
+ const int nk0 = (args.ne00 + nth - 1)/nth;
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ if (op->src[1]) {
+ GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
+ }
+
+ auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
+
+ const int32_t swp = ggml_get_op_params_i32(op, 1);
+ const float alpha = ggml_get_op_params_f32(op, 2);
+ const float limit = ggml_get_op_params_f32(op, 3);
+
+ const int32_t i00 = swp ? ne0 : 0;
+ const int32_t i10 = swp ? 0 : ne0;
+
+ ggml_metal_kargs_glu args = {
+ /*.ne00 =*/ ne00,
+ /*.nb01 =*/ nb01,
+ /*.ne10 =*/ op->src[1] ? ne10 : ne00,
+ /*.nb11 =*/ op->src[1] ? nb11 : nb01,
+ /*.ne0 =*/ ne0,
+ /*.nb1 =*/ nb1,
+ /*.i00 =*/ op->src[1] ? 0 : i00,
+ /*.i10 =*/ op->src[1] ? 0 : i10,
+ /*.alpha=*/ alpha,
+ /*.limit=*/ limit
+ };
+
+ const int64_t nrows = ggml_nrows(op->src[0]);
+
+ const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ if (op->src[1]) {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ } else {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2);
+ }
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
+
+ ggml_metal_kargs_sum args = {
+ /*.np =*/ n,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
+
+ int nth = 32; // SIMD width
+
+ while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, (int) n);
+
+ const int nsg = (nth + 31) / 32;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_sum_rows args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, ne00);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
+
+ int nth = 1;
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
+ nth *= 2;
+ }
+
+ GGML_ASSERT(ne00 <= nth*nth);
+
+ const int64_t net0 = (ne00 + nth - 1) / nth;
+ const int64_t net1 = ne01;
+ const int64_t net2 = ne02;
+ const int64_t net3 = ne03;
+
+ const uint64_t nbt0 = sizeof(float);
+ const uint64_t nbt1 = net0*nbt0;
+ const uint64_t nbt2 = net1*nbt1;
+ const uint64_t nbt3 = net2*nbt2;
+
+ const size_t smem = GGML_PAD(32*sizeof(float), 16);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_buffer_id bid_tmp = bid_dst;
+ bid_tmp.offs += ggml_nbytes(op);
+
+ {
+ ggml_metal_kargs_cumsum_blk args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.net0 =*/ net0,
+ /*.net1 =*/ net1,
+ /*.net2 =*/ net2,
+ /*.net3 =*/ net3,
+ /*.nbt0 =*/ nbt0,
+ /*.nbt1 =*/ nbt1,
+ /*.nbt2 =*/ nbt2,
+ /*.nbt3 =*/ nbt3,
+ /*.outb =*/ ne00 > nth,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+ }
+
+ if (ne00 > nth) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ ggml_metal_kargs_cumsum_blk args = {
+ /*.ne00 =*/ net0,
+ /*.ne01 =*/ net1,
+ /*.ne02 =*/ net2,
+ /*.ne03 =*/ net3,
+ /*.nb00 =*/ nbt0,
+ /*.nb01 =*/ nbt1,
+ /*.nb02 =*/ nbt2,
+ /*.nb03 =*/ nbt3,
+ /*.net0 =*/ net0,
+ /*.net1 =*/ net1,
+ /*.net2 =*/ net2,
+ /*.net3 =*/ net3,
+ /*.nbt0 =*/ nbt0,
+ /*.nbt1 =*/ nbt1,
+ /*.nbt2 =*/ nbt2,
+ /*.nbt3 =*/ nbt3,
+ /*.outb =*/ false,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
+ }
+
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
+
+ ggml_metal_kargs_cumsum_add args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.net0 =*/ net0,
+ /*.net1 =*/ net1,
+ /*.net2 =*/ net2,
+ /*.net3 =*/ net3,
+ /*.nbt0 =*/ nbt0,
+ /*.nbt1 =*/ nbt1,
+ /*.nbt2 =*/ nbt2,
+ /*.nbt3 =*/ nbt3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_add);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+ }
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
+
+ ggml_metal_kargs_get_rows args = {
+ /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
+ /*.ne00 =*/ ne00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const int nw0 = (args.ne00t + nth - 1)/nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
+
+ const int32_t nk0 = ne0/ggml_blck_size(op->type);
+
+ int nth = 32; // SIMD width
+
+ while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ int nrptg = 1;
+ if (nth > nk0) {
+ nrptg = (nth + nk0 - 1)/nk0;
+ nth = nk0;
+
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nrptg--;
+ }
+ }
+
+ nth = std::min(nth, nk0);
+
+ ggml_metal_kargs_set_rows args = {
+ /*.nk0 =*/ nk0,
+ /*.ne01 =*/ ne01,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_diag args = {
+ /*.ne00 =*/ne00,
+ /*.ne01 =*/ne01,
+ /*.ne02 =*/ne02,
+ /*.ne03 =*/ne03,
+ /*.nb00 =*/nb00,
+ /*.nb01 =*/nb01,
+ /*.nb02 =*/nb02,
+ /*.nb03 =*/nb03,
+ /*.ne0 =*/ne0,
+ /*.ne1 =*/ne1,
+ /*.ne2 =*/ne2,
+ /*.ne3 =*/ne3,
+ /*.nb0 =*/nb0,
+ /*.nb1 =*/nb1,
+ /*.nb2 =*/nb2,
+ /*.nb3 =*/nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
+
+ const uint32_t n_head = op->src[0]->ne[2];
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // softmax
+
+ ggml_metal_kargs_soft_max args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
+
+ int nth = 32; // SIMD width
+
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ }
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ if (op->src[1]) {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ } else {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
+ }
+ if (op->src[2]) {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ } else {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+ }
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_ssm_conv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ };
+
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+ const bool use_batched = (ne1 > 1);
+
+ if (use_batched) {
+ // Determine the smallest power of 2 that's >= ne1, but <= 256
+ int BATCH_SIZE;
+ if (ne1 > 128) BATCH_SIZE = 256;
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
+ else BATCH_SIZE = 2;
+
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const ggml_tensor * src3 = op->src[3];
+ const ggml_tensor * src4 = op->src[4];
+ const ggml_tensor * src5 = op->src[5];
+ const ggml_tensor * src6 = op->src[6];
+
+ GGML_ASSERT(src3);
+ GGML_ASSERT(src4);
+ GGML_ASSERT(src5);
+ GGML_ASSERT(src6);
+
+ const int64_t d_state = ne00;
+ const int64_t d_inner = ne01;
+ const int64_t n_head = ne02;
+ const int64_t n_group = ne41;
+ const int64_t n_seq_tokens = ne12;
+ const int64_t n_seqs = ne13;
+
+ ggml_metal_kargs_ssm_scan args = {
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_head =*/ n_head,
+ /*.n_group =*/ n_group,
+ /*.n_seq_tokens =*/ n_seq_tokens,
+ /*.n_seqs =*/ n_seqs,
+ /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.ns12 =*/ nb12/nb10,
+ /*.nb13 =*/ nb13,
+ /*.nb20 =*/ nb20,
+ /*.nb21 =*/ nb21,
+ /*.ns21 =*/ nb21/nb20,
+ /*.nb22 =*/ nb22,
+ /*.ne30 =*/ ne30,
+ /*.nb31 =*/ nb31,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.ns42 =*/ nb42/nb40,
+ /*.nb43 =*/ nb43,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ /*.ns52 =*/ nb52/nb50,
+ /*.nb53 =*/ nb53,
+ /*.nb0 =*/ nb0,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
+
+ GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
+ const int64_t T = op->src[0]->ne[2];
+ const int64_t C = op->ne[0];
+ const int64_t H = op->src[0]->ne[1];
+
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
+
+ int ida = 0;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
+ if (op->op == GGML_OP_RWKV_WKV7) {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
+ }
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const int nsg = pipeline.nsg;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+ GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
+
+ int64_t nk0 = ne00;
+ if (ggml_is_quantized(op->src[0]->type)) {
+ nk0 = ne00/16;
+ } else if (ggml_is_quantized(op->type)) {
+ nk0 = ne00/ggml_blck_size(op->type);
+ }
+
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ // when rows are small, we can batch them together in a single threadgroup
+ int nrptg = 1;
+
+ // TODO: relax this constraint in the future
+ if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
+ if (nth > nk0) {
+ nrptg = (nth + nk0 - 1)/nk0;
+ nth = nk0;
+
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nrptg--;
+ }
+ }
+ }
+
+ nth = std::min<int>(nth, nk0);
+
+ ggml_metal_kargs_cpy args = {
+ /*.nk0 =*/ nk0,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t * opts = op->op_params;
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+ const int32_t k0 = opts[1];
+ const int32_t s0 = opts[2];
+ const int32_t p0 = opts[3];
+
+ const int64_t IW = op->src[0]->ne[0];
+ const int64_t OW = op->ne[0];
+
+ const int64_t np = ggml_nelements(op);
+
+ ggml_metal_kargs_pool_1d args_pool_1d = {
+ /* .k0 = */ k0,
+ /* .s0 = */ s0,
+ /* .p0 = */ p0,
+ /* .IW = */ IW,
+ /* .OW = */ OW,
+ /* .np = */ np
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+ const int ntg = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+
+int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t * opts = op->op_params;
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+ const int32_t k0 = opts[1];
+ const int32_t k1 = opts[2];
+ const int32_t s0 = opts[3];
+ const int32_t s1 = opts[4];
+ const int32_t p0 = opts[5];
+ const int32_t p1 = opts[6];
+
+ const int64_t IH = op->src[0]->ne[1];
+ const int64_t IW = op->src[0]->ne[0];
+
+ const int64_t N = op->ne[3];
+ const int64_t OC = op->ne[2];
+ const int64_t OH = op->ne[1];
+ const int64_t OW = op->ne[0];
+
+ const int64_t np = N * OC * OH * OW;
+
+ ggml_metal_kargs_pool_2d args_pool_2d = {
+ /* .k0 = */ k0,
+ /* .k1 = */ k1,
+ /* .s0 = */ s0,
+ /* .s1 = */ s1,
+ /* .p0 = */ p0,
+ /* .p1 = */ p1,
+ /* .IH = */ IH,
+ /* .IW = */ IW,
+ /* .OH = */ OH,
+ /* .OW = */ OW,
+ /* .np = */ np
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+ const int ntg = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const int16_t r2 = ne12/ne02;
+ const int16_t r3 = ne13/ne03;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ const int ne11_mm_min = 8;
+
+ // first try to use small-batch mat-mv kernels
+ // these should be efficient for BS [2, ~8]
+ if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&
+ (
+ (
+ (
+ op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_Q4_0 ||
+ op->src[0]->type == GGML_TYPE_Q4_1 ||
+ op->src[0]->type == GGML_TYPE_Q5_0 ||
+ op->src[0]->type == GGML_TYPE_Q5_1 ||
+ op->src[0]->type == GGML_TYPE_Q8_0 ||
+ op->src[0]->type == GGML_TYPE_MXFP4 ||
+ op->src[0]->type == GGML_TYPE_IQ4_NL ||
+ false) && (ne11 >= 2 && ne11 <= 8)
+ ) ||
+ (
+ (
+ op->src[0]->type == GGML_TYPE_Q4_K ||
+ op->src[0]->type == GGML_TYPE_Q5_K ||
+ op->src[0]->type == GGML_TYPE_Q6_K ||
+ false) && (ne11 >= 4 && ne11 <= 8)
+ )
+ )
+ ) {
+ // TODO: determine the optimal parameters based on grid utilization
+ // I still don't know why we should not always use the maximum available threads:
+ //
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+ //
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
+ // values and there can be some tail effects when nsg is high. need to confirm this
+ //
+ const int nsg = 2; // num simdgroups per threadgroup
+
+ // num threads along row per simdgroup
+ int16_t nxpsg = 0;
+ if (ne00 % 256 == 0 && ne11 < 3) {
+ nxpsg = 16;
+ } else if (ne00 % 128 == 0) {
+ nxpsg = 8;
+ } else {
+ nxpsg = 4;
+ }
+
+ const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+ const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
+ int16_t r1ptg = 4; // num src1 rows per threadgroup
+
+ // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+ switch (ne11) {
+ case 2:
+ r1ptg = 2; break;
+ case 3:
+ case 6:
+ r1ptg = 3; break;
+ case 4:
+ case 7:
+ case 8:
+ r1ptg = 4; break;
+ case 5:
+ r1ptg = 5; break;
+ default:
+ GGML_ABORT("unsupported ne11");
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
+
+ ggml_metal_kargs_mul_mv_ext args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
+ } else if (
+ !ggml_is_transposed(op->src[0]) &&
+ !ggml_is_transposed(op->src[1]) &&
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
+ //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ //switch (op->src[0]->type) {
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // default: break;
+ //}
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
+
+ ggml_metal_kargs_mul_mm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
+
+ const int nr0 = pipeline.nr0;
+ const int nr1 = pipeline.nr1;
+ const int nsg = pipeline.nsg;
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_kargs_mul_mv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nr0 =*/ nr0,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ if (op->src[0]->type == GGML_TYPE_F32 ||
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_BF16 ||
+ op->src[0]->type == GGML_TYPE_Q8_0) {
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
+ } else {
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
+ }
+ }
+
+ return 1;
+}
+
+size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_MUL_MAT_ID);
+
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+
+ return ggml_type_size(GGML_TYPE_I32)*ne02;
+}
+
+size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_MUL_MAT_ID);
+
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+ const int64_t ne21 = op->src[2]->ne[1]; // n_token
+
+ return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
+}
+
+int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ // src2 = ids
+ GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(!ggml_is_transposed(op->src[0]));
+ GGML_ASSERT(!ggml_is_transposed(op->src[1]));
+
+ GGML_ASSERT(ne03 == 1);
+ GGML_ASSERT(ne13 == 1);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ const uint32_t r2 = 1;
+ const uint32_t r3 = 1;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ // ne20 = n_used_experts
+ // ne21 = n_rows (batch size)
+ const int ne21_mm_id_min = 32;
+
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ //switch (op->src[0]->type) {
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // default: break;
+ //}
+
+ // extra buffers for intermediate id mapping
+ ggml_metal_buffer_id bid_tpe = bid_dst;
+ bid_tpe.offs += ggml_nbytes(op);
+
+ ggml_metal_buffer_id bid_ids = bid_tpe;
+ bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);
+
+ {
+ ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne02,
+ ne10,
+ ne11, // n_expert_used (bcast)
+ nb11,
+ nb12,
+ ne21, // n_tokens
+ ne20, // n_expert_used
+ nb21,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
+
+ const size_t smem = pipeline.smem;
+
+ GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_tpe, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_ids, 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
+ }
+
+ // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
+
+ ggml_metal_kargs_mul_mm_id args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne20 =*/ ne20, // n_expert_used
+ /*.ne21 =*/ ne21, // n_tokens
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tpe, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
+ }
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
+
+ const int nr0 = pipeline.nr0;
+ const int nr1 = pipeline.nr1;
+ const int nsg = pipeline.nsg;
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_kargs_mul_mv_id args = {
+ /*.nei0 =*/ ne20,
+ /*.nei1 =*/ ne21,
+ /*.nbi1 =*/ nb21,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb1 =*/ nb1,
+ /*.nr0 =*/ nr0,
+ };
+
+ if (ggml_is_quantized(op->src[0]->type)) {
+ GGML_ASSERT(ne00 >= nsg*nr0);
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 3);
+ ggml_metal_encoder_set_buffer(enc, bid_src2, 4);
+
+ const int64_t _ne1 = 1;
+ const int64_t ne123 = ne20*ne21;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ if (op->src[0]->type == GGML_TYPE_F32 ||
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_BF16 ||
+ op->src[0]->type == GGML_TYPE_Q8_0) {
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
+ } else {
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
+ }
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_kargs_add_id args = {
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb11 =*/ nb11,
+ /*.nb21 =*/ nb21,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
+
+ return 1;
+}
+
+bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ const int64_t ne00 = op->src[0]->ne[0]; // head size
+ const int64_t ne01 = op->src[0]->ne[1]; // batch size
+
+ // use vec kernel if the batch size is small and if the head size is supported
+ return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ const bool has_mask = op->src[3] != nullptr;
+
+ // note: the non-vec kernel requires more extra memory, so always reserve for it
+ GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
+
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ if (false) {
+ // note: always reserve the padding space to avoid graph reallocations
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+ const bool has_kvpad = true;
+
+ if (has_kvpad) {
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
+ nb11*ne12*ne13 +
+ nb21*ne22*ne23 +
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+ }
+ } else {
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+ const bool has_kvpad = true;
+
+ if (has_kvpad) {
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
+ nb11*ne12*ne13 +
+ nb21*ne22*ne23 +
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+ }
+ }
+
+ return res;
+}
+
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ const bool has_mask = op->src[3] != nullptr;
+
+ if (!has_mask) {
+ return res;
+ }
+
+ const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+ // this optimization is not useful for the vector kernels
+ // note: always reserve the blk buffer to avoid graph reallocations
+ //if (is_vec) {
+ // return res;
+ //}
+
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+ res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+ return res;
+}
+
+size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ // note: always reserve the temp buffer to avoid graph reallocations
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ if (true) {
+ const int64_t nwg = 32;
+ const int64_t ne01_max = std::min(ne01, 32);
+
+ // temp buffer for writing the results from each workgroup
+ // - ne20: the size of the Value head
+ // - + 2: the S and M values for each intermediate result
+ res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
+ }
+
+ return res;
+}
+
+int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
+
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == op->src[2]->type);
+
+ //GGML_ASSERT(ggml_are_same_shape (src1, src2));
+ GGML_ASSERT(ne11 == ne21);
+ GGML_ASSERT(ne12 == ne22);
+
+ GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
+ GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
+ memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const bool has_mask = op->src[3] != NULL;
+ const bool has_sinks = op->src[4] != NULL;
+ const bool has_bias = max_bias != 0.0f;
+ const bool has_scap = logit_softcap != 0.0f;
+
+ const uint32_t n_head = op->src[0]->ne[2];
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ GGML_ASSERT(ne01 < 65536);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_buffer_id bid_pad = bid_dst;
+ bid_pad.offs += ggml_nbytes(op);
+
+ ggml_metal_buffer_id bid_blk = bid_pad;
+ bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+ ggml_metal_buffer_id bid_tmp = bid_blk;
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
+
+ if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ // half8x8 kernel
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ bool need_sync = false;
+
+ const bool has_kvpad = ne11 % ncpsg != 0;
+
+ if (has_kvpad) {
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
+ /*.ne11 =*/ne11,
+ /*.ne_12_2 =*/ne12,
+ /*.ne_12_3 =*/ne13,
+ /*.nb11 =*/nb11,
+ /*.nb12 =*/nb12,
+ /*.nb13 =*/nb13,
+ /*.nb21 =*/nb21,
+ /*.nb22 =*/nb22,
+ /*.nb23 =*/nb23,
+ /*.ne31 =*/ne31,
+ /*.ne32 =*/ne32,
+ /*.ne33 =*/ne33,
+ /*.nb31 =*/nb31,
+ /*.nb32 =*/nb32,
+ /*.nb33 =*/nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
+
+ assert(ne12 == ne22);
+ assert(ne13 == ne23);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (has_mask) {
+ assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_blk args0 = {
+ /*.ne01 =*/ ne01,
+ /*.ne30 =*/ ne30,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
+
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (need_sync) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
+
+ // 2*(2*ncpsg)
+ // ncpsg soft_max values + ncpsg mask values
+ //
+ // 16*32*(nsg)
+ // the shared memory needed for the simdgroups to load the KV cache
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
+ //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
+
+ //int64_t nsgmax = 4;
+ //
+ //if (is_q) {
+ // nsgmax = 2;
+ // while (true) {
+ // const size_t smem = FATTN_SMEM(nsgmax);
+ // if (smem > props_dev->max_theadgroup_memory_size) {
+ // break;
+ // }
+ // nsgmax *= 2;
+ // }
+ // nsgmax /= 2;
+ //}
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ int32_t(nb11/nb10),
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ int32_t(nb21/nb20),
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
+#undef FATTN_SMEM
+ } else {
+ // half4x4 kernel
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int nhptg = 1; // heads per threadgroup
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 1 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ bool need_sync = false;
+
+ const bool has_kvpad = ne11 % ncpsg != 0;
+
+ if (has_kvpad) {
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
+ /*.ne11 =*/ne11,
+ /*.ne_12_2 =*/ne12,
+ /*.ne_12_3 =*/ne13,
+ /*.nb11 =*/nb11,
+ /*.nb12 =*/nb12,
+ /*.nb13 =*/nb13,
+ /*.nb21 =*/nb21,
+ /*.nb22 =*/nb22,
+ /*.nb23 =*/nb23,
+ /*.ne31 =*/ne31,
+ /*.ne32 =*/ne32,
+ /*.ne33 =*/ne33,
+ /*.nb31 =*/nb31,
+ /*.nb32 =*/nb32,
+ /*.nb33 =*/nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
+
+ assert(ne12 == ne22);
+ assert(ne13 == ne23);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (need_sync) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ // note: for simplicity assume the K is larger or equal than V
+ GGML_ASSERT(ne10 >= ne20);
+
+ // ne00 + 2*ncpsg*(nsg)
+ // for each query, we load it as f16 in shared memory (ne00)
+ // and store the soft_max values and the mask
+ //
+ // ne20*(nsg)
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
+ //
+#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsg = 1;
+
+ // workgroups
+ // each workgroup handles nsg*nkpsg cache values
+ int32_t nwg = 1;
+ if (false) {
+ // for small KV caches, we could launch a single workgroup and write the results directly to dst/
+ // however, this does not lead to significant improvement, so disabled
+ nwg = 1;
+ nsg = 4;
+ } else {
+ nwg = 32;
+ nsg = 1;
+ while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
+ nsg *= 2;
+ }
+ }
+
+ ggml_metal_kargs_flash_attn_ext_vec args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ int32_t(nb11/nb10),
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ int32_t(nb21/nb20),
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
+
+ GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
+
+ if (nwg == 1) {
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
+ // using 1 workgroup -> write the result directly into dst
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
+ } else {
+ // sanity checks
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+ GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
+
+ // write the results from each workgroup into a temp buffer
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
+
+ // sync the 2 kernels
+ ggml_metal_op_concurrency_reset(ctx);
+
+ // reduce the results from the workgroups
+ {
+ const int32_t nrows = ne1*ne2*ne3;
+
+ ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
+ nrows,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
+ }
+ }
+#undef FATTN_SMEM
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const bool use_fusion = ctx->use_fusion;
+
+ const int debug_fusion = ctx->debug_fusion;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_kargs_bin args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.offs =*/ 0,
+ /*.o1 =*/ { bid_src1.offs },
+ };
+
+ ggml_op fops[8];
+
+ int n_fuse = 1;
+
+ // c[0] = add(a, b[0])
+ // c[1] = add(c[0], b[1])
+ // c[2] = add(c[1], b[2])
+ // ...
+ if (use_fusion) {
+ fops[0] = GGML_OP_ADD;
+ fops[1] = GGML_OP_ADD;
+ fops[2] = GGML_OP_ADD;
+ fops[3] = GGML_OP_ADD;
+ fops[4] = GGML_OP_ADD;
+ fops[5] = GGML_OP_ADD;
+ fops[6] = GGML_OP_ADD;
+ fops[7] = GGML_OP_ADD;
+
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
+ // across splits. idx_end indicates the last node in the current split
+ for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+ if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
+ break;
+ }
+
+ ggml_tensor * f0 = ctx->node(idx + n_fuse);
+ ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
+
+ if (f0 != f1->src[0]) {
+ break;
+ }
+
+ // b[0] === b[1] === ...
+ if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
+ break;
+ }
+
+ // only fuse ops if src1 is in the same Metal buffer
+ ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
+ if (bid_fuse.metal != bid_src1.metal) {
+ break;
+ }
+
+ //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
+
+ args.o1[n_fuse + 1] = bid_fuse.offs;
+ }
+
+ ++n_fuse;
+
+ if (debug_fusion > 1 && n_fuse > 1) {
+ GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+ }
+ }
+
+ // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
+ bid_src1.offs = 0;
+
+ struct ggml_metal_pipeline_with_params pipeline;
+
+ pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
+
+ if (n_fuse > 1) {
+ bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
+
+ for (int i = 1; i < n_fuse; ++i) {
+ if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ break;
+ }
+ }
+ }
+
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne10 = ne10/4;
+ args.ne0 = ne0/4;
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
+
+ if (pipeline.cnt) {
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ } else {
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ int nth = 1;
+
+ while (2*nth < args.ne0 && nth < nth_max) {
+ nth *= 2;
+ }
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+ }
+
+ return n_fuse;
+}
+
+int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ float eps;
+ memcpy(&eps, op->op_params, sizeof(float));
+
+ ggml_metal_kargs_l2_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.eps =*/ eps,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
+
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne0 = ne0/4;
+ }
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t ngrp = ((const int32_t *) op->op_params)[0];
+
+ float eps;
+ memcpy(&eps, op->op_params + 1, sizeof(float));
+
+ ggml_metal_kargs_group_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ngrp =*/ ngrp,
+ /*.eps =*/ eps,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
+
+ int nth = 32; // SIMD width
+ //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ // nth *= 2;
+ //}
+
+ //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ //nth = std::min(nth, ne00/4);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const bool use_fusion = ctx->use_fusion;
+
+ const int debug_fusion = ctx->debug_fusion;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ float eps;
+ memcpy(&eps, op->op_params, sizeof(float));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_kargs_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.eps =*/ eps,
+ /*.nef1 =*/ { ne01 },
+ /*.nef2 =*/ { ne02 },
+ /*.nef3 =*/ { ne03 },
+ /*.nbf1 =*/ { nb01 },
+ /*.nbf2 =*/ { nb02 },
+ /*.nbf3 =*/ { nb03 },
+ };
+
+ ggml_op fops[8];
+
+ int n_fuse = 1;
+
+ ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
+
+ // d[0] = norm(a)
+ // d[1] = mul(d[0], b)
+ // d[2] = add(d[1], c)
+ if (use_fusion) {
+ fops[0] = op->op;
+ fops[1] = GGML_OP_MUL;
+ fops[2] = GGML_OP_ADD;
+
+ for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+ if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
+ break;
+ }
+
+ ggml_tensor * f0 = ctx->node(idx + n_fuse);
+ ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
+
+ if (f0 != f1->src[0]) {
+ break;
+ }
+
+ if (f1->src[1]->ne[0] != op->ne[0]) {
+ break;
+ }
+
+ if (!ggml_is_contiguous_rows(f1->src[1])) {
+ break;
+ }
+
+ if (f1->type != GGML_TYPE_F32) {
+ break;
+ }
+
+ //ctx->fuse_cnt[f1->op]++;
+
+ bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
+
+ args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
+ args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
+ args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
+
+ args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
+ args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
+ args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
+ }
+
+ ++n_fuse;
+
+ if (debug_fusion > 1 && n_fuse > 1) {
+ if (n_fuse == 2) {
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
+ }
+ if (n_fuse == 3) {
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
+ }
+ }
+ }
+
+ if (n_fuse > 1) {
+ bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
+
+ for (int i = 1; i < n_fuse; ++i) {
+ if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ break;
+ }
+ }
+ }
+
+ auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
+
+ int nth = 32; // SIMD width
+
+ while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, args.ne00_t);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return n_fuse;
+}
+
+int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ // make sure we have one or more position id(ne10) per token(ne02)
+ GGML_ASSERT(ne10 % ne02 == 0);
+ GGML_ASSERT(ne10 >= ne02);
+
+ const int nth = std::min(1024, ne00);
+
+ const int n_past = ((const int32_t *) op->op_params)[0];
+ const int n_dims = ((const int32_t *) op->op_params)[1];
+ //const int mode = ((const int32_t *) op->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
+
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float));
+
+ // mrope
+ const int sect_0 = ((const int32_t *) op->op_params)[11];
+ const int sect_1 = ((const int32_t *) op->op_params)[12];
+ const int sect_2 = ((const int32_t *) op->op_params)[13];
+ const int sect_3 = ((const int32_t *) op->op_params)[14];
+
+ ggml_metal_kargs_rope args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.n_past =*/ n_past,
+ /*.n_dims =*/ n_dims,
+ /*.n_ctx_orig =*/ n_ctx_orig,
+ /*.freq_base =*/ freq_base,
+ /*.freq_scale =*/ freq_scale,
+ /*.ext_factor =*/ ext_factor,
+ /*.attn_factor =*/ attn_factor,
+ /*.beta_fast =*/ beta_fast,
+ /*.beta_slow =*/ beta_slow,
+ /* sect_0 =*/ sect_0,
+ /* sect_1 =*/ sect_1,
+ /* sect_2 =*/ sect_2,
+ /* sect_3 =*/ sect_3,
+ /* src2 =*/ op->src[2] != nullptr,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ if (op->src[2]) {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ } else {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+ }
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(op->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(op->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(op->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(op->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(op->op_params))[5];
+
+ const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
+
+ const int32_t N = op->src[1]->ne[is_2D ? 3 : 2];
+ const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
+ const int32_t IW = op->src[1]->ne[0];
+
+ const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
+ const int32_t KW = op->src[0]->ne[0];
+
+ const int32_t OH = is_2D ? op->ne[2] : 1;
+ const int32_t OW = op->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
+ const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
+
+ ggml_metal_kargs_im2col args = {
+ /*.ofs0 =*/ ofs0,
+ /*.ofs1 =*/ ofs1,
+ /*.IW =*/ IW,
+ /*.IH =*/ IH,
+ /*.CHW =*/ CHW,
+ /*.s0 =*/ s0,
+ /*.s1 =*/ s1,
+ /*.p0 =*/ p0,
+ /*.p1 =*/ p1,
+ /*.d0 =*/ d0,
+ /*.d1 =*/ d1,
+ /*.N =*/ N,
+ /*.KH =*/ KH,
+ /*.KW =*/ KW,
+ /*.KHW =*/ KH * KW,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
+
+ GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
+
+ return 1;
+}
+
+int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
+ ggml_metal_kargs_conv_2d args = {
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.IW =*/ ne10,
+ /*.IH =*/ ne11,
+ /*.KW =*/ ne00,
+ /*.KH =*/ ne01,
+ /*.IC =*/ ne02,
+ /*.OC =*/ ne03,
+ /*.OW =*/ ne0,
+ /*.OH =*/ ne1,
+ /*.N =*/ ne3,
+ /*.s0 =*/ s0,
+ /*.s1 =*/ s1,
+ /*.p0 =*/ p0,
+ /*.p1 =*/ p1,
+ /*.d0 =*/ d0,
+ /*.d1 =*/ d1,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+ int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+ nth = std::min(nth, 256);
+ nth = std::max(nth, 1);
+
+ const uint64_t n_out = ggml_nelements(op);
+
+ uint64_t tg = (n_out + nth - 1)/nth;
+ tg = std::max<uint64_t>(tg, 1);
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+ const int32_t IC = op->src[1]->ne[1];
+ const int32_t IL = op->src[1]->ne[0];
+
+ const int32_t K = op->src[0]->ne[0];
+
+ const int32_t OL = op->ne[0];
+ const int32_t OC = op->ne[1];
+
+ ggml_metal_kargs_conv_transpose_1d args = {
+ /*.IC =*/ IC,
+ /*.IL =*/ IL,
+ /*.K =*/ K,
+ /*.s0 =*/ s0,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+ const int32_t IC = op->src[1]->ne[2];
+ const int32_t IH = op->src[1]->ne[1];
+ const int32_t IW = op->src[1]->ne[0];
+
+ const int32_t KH = op->src[0]->ne[1];
+ const int32_t KW = op->src[0]->ne[0];
+
+ const int32_t OW = op->ne[0];
+ const int32_t OH = op->ne[1];
+ const int32_t OC = op->ne[2];
+
+ ggml_metal_kargs_conv_transpose_2d args = {
+ /*.IC =*/ IC,
+ /*.IH =*/ IH,
+ /*.IW =*/ IW,
+ /*.KH =*/ KH,
+ /*.KW =*/ KW,
+ /*.OC =*/ OC,
+ /*.s0 =*/ s0,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ // Metal requires buffer size to be multiple of 16 bytes
+ const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const float sf0 = (float)ne0/op->src[0]->ne[0];
+ const float sf1 = (float)ne1/op->src[0]->ne[1];
+ const float sf2 = (float)ne2/op->src[0]->ne[2];
+ const float sf3 = (float)ne3/op->src[0]->ne[3];
+
+ ggml_metal_kargs_upscale args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.sf0 =*/ sf0,
+ /*.sf1 =*/ sf1,
+ /*.sf2 =*/ sf2,
+ /*.sf3 =*/ sf3
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_pad args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
+
+ const int nth = std::min(1024, ne0);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_pad_reflect_1d args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.p0 =*/ ((const int32_t *)(op->op_params))[0],
+ /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
+
+ const int nth = std::min(1024, ne0);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ float start;
+ float step;
+
+ memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
+ memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float));
+
+ ggml_metal_kargs_arange args = {
+ /*.ne0 =*/ ne0,
+ /*.start =*/ start,
+ /*.step =*/ step
+ };
+
+ const int nth = std::min(1024, ne0);
+
+ auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int dim = op->op_params[0];
+ const int max_period = op->op_params[1];
+
+ ggml_metal_kargs_timestep_embedding args = {
+ /*.nb1 =*/ nb1,
+ /*.dim =*/ dim,
+ /*.max_period =*/ max_period,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
+
+ const int nth = std::max(1, std::min(1024, dim/2));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_argmax args = {
+ /*.ne00 = */ ne00,
+ /*.nb01 = */ nb01,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
+
+ const int64_t nrows = ggml_nrows(op->src[0]);
+
+ int nth = 32; // SIMD width
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+
+ // bitonic sort requires the number of elements to be power of 2
+ int nth = 1;
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ const int npr = (ne00 + nth - 1)/nth;
+
+ // Metal kernels require the buffer size to be multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_buffer_id bid_tmp = bid_dst;
+ bid_tmp.offs += ggml_nbytes(op);
+
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+ std::swap(bid_dst, bid_tmp);
+ }
+
+ ggml_metal_kargs_argsort args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ nth,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+ auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+ int len = nth;
+
+ while (len < ne00) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ ggml_metal_kargs_argsort_merge args_merge = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ ne00,
+ /*.len =*/ len,
+ };
+
+ // merges per row
+ const int nm = (ne00 + 2*len - 1) / (2*len);
+
+ const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+ std::swap(bid_dst, bid_tmp);
+
+ len <<= 1;
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+
+ // bitonic sort requires the number of elements to be power of 2
+ int nth = 1;
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ // blocks per row
+ const int npr = (ne00 + nth - 1)/nth;
+
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_buffer_id bid_tmp = bid_dst;
+ bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+ std::swap(bid_dst, bid_tmp);
+ }
+
+ const int top_k = ne0;
+
+ ggml_metal_kargs_argsort args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
+ };
+
+ if (npr > 1) {
+ args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+ auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+
+ int len = args.top_k;
+
+ while (len < args.ne0) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ // merges per row
+ const int nm = (args.ne0 + 2*len - 1) / (2*len);
+
+ const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+ ggml_metal_kargs_argsort_merge args_merge = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ args.ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
+ /*.len =*/ len,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+ std::swap(bid_dst, bid_tmp);
+
+ len <<= 1;
+ }
+
+ return 1;
+}
+
+int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, ne00);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+
+ const int64_t np = ggml_nelements(op->src[0]);
+ ggml_metal_kargs_opt_step_adamw args = {
+ /*.np =*/ np,
+ };
+
+ int ida = 0;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+ const int64_t n = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+ const int64_t np = ggml_nelements(op->src[0]);
+ ggml_metal_kargs_opt_step_sgd args = {
+ /*.np =*/ np,
+ };
+
+ int ida = 0;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+ const int64_t n = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
+int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+
+ {
+ ggml_metal_kargs_memset args = { /*.val =*/ 0 };
+
+ auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+ }
+
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ ggml_metal_kargs_count_equal args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
+
+ const size_t smem = pipeline.smem;
+
+ const int nth = 32*pipeline.nsg;
+
+ GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+ }
+
+ return 1;
+}