1#include "ggml-metal-ops.h"
2
3#include "ggml.h"
4#include "ggml-impl.h"
5#include "ggml-backend-impl.h"
6
7#include "ggml-metal-impl.h"
8#include "ggml-metal-common.h"
9#include "ggml-metal-device.h"
10
11#include <cassert>
12#include <algorithm>
13#include <limits>
14#include <cmath>
15
16static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
17 if (!t) {
18 return { nullptr, 0 };
19 }
20
21 ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
22
23 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;
24
25 return ggml_metal_buffer_get_id(ctx, t);
26}
27
28struct ggml_metal_op {
29 ggml_metal_op(
30 ggml_metal_device_t dev,
31 ggml_metal_cmd_buf_t cmd_buf,
32 ggml_cgraph * gf,
33 int idx_start,
34 int idx_end,
35 bool use_fusion,
36 bool use_concurrency,
37 bool use_capture,
38 int debug_graph,
39 int debug_fusion) {
40 this->dev = dev;
41 this->lib = ggml_metal_device_get_library(dev);
42 this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
43 this->mem_ranges = ggml_mem_ranges_init(debug_graph);
44 this->idx_start = idx_start;
45 this->idx_end = idx_end;
46 this->use_fusion = use_fusion;
47 this->use_concurrency = use_concurrency;
48 this->use_capture = use_capture;
49 this->debug_graph = debug_graph;
50 this->debug_fusion = debug_fusion;
51 this->gf = gf;
52
53 idxs.reserve(gf->n_nodes);
54
55 // filter empty nodes
56 // TODO: this can be removed when the allocator starts filtering them earlier
57 // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
58 for (int i = idx_start; i < idx_end; i++) {
59 if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
60 idxs.push_back(i);
61 }
62 }
63 }
64
65 ~ggml_metal_op() {
66 ggml_metal_encoder_end_encoding(this->enc);
67 ggml_metal_encoder_free(this->enc);
68 ggml_mem_ranges_free(this->mem_ranges);
69 }
70
71 int n_nodes() const {
72 return idxs.size();
73 }
74
75 ggml_tensor * node(int i) const {
76 assert(i >= 0 && i < (int) idxs.size());
77 return ggml_graph_node(gf, idxs[i]);
78 }
79
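    // check whether the n_ops nodes starting at position i0 match the given op sequence
    // and are eligible for fusion into a single kernel (delegates to ggml_can_fuse_ext)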
80 bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
81 assert(use_fusion);
82 assert(i0 >= 0 && i0 < n_nodes());
83
84 if (i0 + n_ops > n_nodes()) {
85 return false;
86 }
87
88 return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
89 }
90
91 ggml_metal_device_t dev;
92 ggml_metal_library_t lib;
93 ggml_metal_encoder_t enc;
94 ggml_mem_ranges_t mem_ranges;
95
96 bool use_fusion;
97 bool use_concurrency;
98 bool use_capture;
99
100 int debug_graph;
101 int debug_fusion;
102
103private:
104 ggml_cgraph * gf;
105
106 int idx_start;
107 int idx_end;
108
109 // non-empty node indices
110 std::vector<int> idxs;
111};
112
113ggml_metal_op_t ggml_metal_op_init(
114 ggml_metal_device_t dev,
115 ggml_metal_cmd_buf_t cmd_buf,
116 ggml_cgraph * gf,
117 int idx_start,
118 int idx_end,
119 bool use_fusion,
120 bool use_concurrency,
121 bool use_capture,
122 int debug_graph,
123 int debug_fusion) {
124 ggml_metal_op_t res = new ggml_metal_op(
125 dev,
126 cmd_buf,
127 gf,
128 idx_start,
129 idx_end,
130 use_fusion,
131 use_concurrency,
132 use_capture,
133 debug_graph,
134 debug_fusion);
135
136 return res;
137}
138
139void ggml_metal_op_free(ggml_metal_op_t ctx) {
140 delete ctx;
141}
142
143int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
144 return ctx->n_nodes();
145}
146
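// emit a memory barrier and clear the tracked memory ranges
// after this, previously encoded nodes are no longer considered for concurrent execution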
147static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
148 if (!ctx->mem_ranges) {
149 return true;
150 }
151
152 ggml_metal_encoder_memory_barrier(ctx->enc);
153
154 ggml_mem_ranges_reset(ctx->mem_ranges);
155
156 return true;
157}
158
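// check whether the node's memory ranges conflict with the ranges tracked so far
// returns true if the node can be encoded concurrently with the previous nodes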
159static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {
160 if (!ctx->mem_ranges) {
161 return false;
162 }
163
164 return ggml_mem_ranges_check(ctx->mem_ranges, node);
165}
166
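// record the node's memory ranges in the tracker
// returns false if the ranges could not be added, in which case the caller resets the tracker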
167static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {
168 if (!ctx->mem_ranges) {
169 return true;
170 }
171
172 return ggml_mem_ranges_add(ctx->mem_ranges, node);
173}
174
175static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
176 struct ggml_tensor * node = ctx->node(idx);
177
178 //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
179
180 if (ggml_is_empty(node)) {
181 return 1;
182 }
183
184 switch (node->op) {
185 case GGML_OP_NONE:
186 case GGML_OP_RESHAPE:
187 case GGML_OP_VIEW:
188 case GGML_OP_TRANSPOSE:
189 case GGML_OP_PERMUTE:
190 {
191 // noop -> next node
192 if (ctx->debug_graph > 0) {
193 GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
194 }
195 } return 1;
196 default:
197 {
198 } break;
199 }
200
201 if (!ggml_metal_device_supports_op(ctx->dev, node)) {
202 GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node));
203 GGML_ABORT("unsupported op");
204 }
205
206 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
207 return 1;
208 }
209
210 int n_fuse = 1;
211
212 // check if the current node can run concurrently with other nodes before it
213 // the condition is that:
214 // - the current node cannot write to any previous src or dst ranges
215 // - the current node cannot read from any previous dst ranges
216 //
217 // if the condition is not satisfied, we put a memory barrier and clear all ranges
218 // otherwise, we add the new ranges to the encoding context and process the node concurrently
219 //
220 {
221 const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);
222
223 if (!is_concurrent) {
224 ggml_metal_op_concurrency_reset(ctx);
225 }
226
227 if (ctx->debug_graph > 0) {
228 GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
229 }
230 if (ctx->debug_graph > 1) {
231 GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
232 GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
233 GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
234 GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
235 GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
236 GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
237 GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
238 GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
239 GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
240 GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
241
242 if (node->src[0]) {
243 GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
244 ggml_is_contiguous(node->src[0]), node->src[0]->name);
245 }
246 if (node->src[1]) {
247 GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
248 ggml_is_contiguous(node->src[1]), node->src[1]->name);
249 }
250 if (node->src[2]) {
251 GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
252 ggml_is_contiguous(node->src[2]), node->src[2]->name);
253 }
254 if (node->src[3]) {
255 GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
256 ggml_is_contiguous(node->src[3]), node->src[3]->name);
257 }
258 if (node) {
259 GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
260 node->name);
261 }
262 }
263 }
264
265 switch (node->op) {
266 case GGML_OP_CONCAT:
267 {
268 n_fuse = ggml_metal_op_concat(ctx, idx);
269 } break;
270 case GGML_OP_ADD:
271 case GGML_OP_SUB:
272 case GGML_OP_MUL:
273 case GGML_OP_DIV:
274 {
275 n_fuse = ggml_metal_op_bin(ctx, idx);
276 } break;
277 case GGML_OP_ADD_ID:
278 {
279 n_fuse = ggml_metal_op_add_id(ctx, idx);
280 } break;
281 case GGML_OP_REPEAT:
282 {
283 n_fuse = ggml_metal_op_repeat(ctx, idx);
284 } break;
285 case GGML_OP_ACC:
286 {
287 n_fuse = ggml_metal_op_acc(ctx, idx);
288 } break;
289 case GGML_OP_SCALE:
290 case GGML_OP_FILL:
291 case GGML_OP_CLAMP:
292 case GGML_OP_LEAKY_RELU:
293 case GGML_OP_SQR:
294 case GGML_OP_SQRT:
295 case GGML_OP_SIN:
296 case GGML_OP_COS:
297 case GGML_OP_LOG:
298 case GGML_OP_UNARY:
299 {
300 n_fuse = ggml_metal_op_unary(ctx, idx);
301 } break;
302 case GGML_OP_GLU:
303 {
304 n_fuse = ggml_metal_op_glu(ctx, idx);
305 } break;
306 case GGML_OP_SUM:
307 {
308 n_fuse = ggml_metal_op_sum(ctx, idx);
309 } break;
310 case GGML_OP_SUM_ROWS:
311 case GGML_OP_MEAN:
312 {
313 n_fuse = ggml_metal_op_sum_rows(ctx, idx);
314 } break;
315 case GGML_OP_CUMSUM:
316 {
317 n_fuse = ggml_metal_op_cumsum(ctx, idx);
318 } break;
319 case GGML_OP_SOFT_MAX:
320 {
321 n_fuse = ggml_metal_op_soft_max(ctx, idx);
322 } break;
323 case GGML_OP_SSM_CONV:
324 {
325 n_fuse = ggml_metal_op_ssm_conv(ctx, idx);
326 } break;
327 case GGML_OP_SSM_SCAN:
328 {
329 n_fuse = ggml_metal_op_ssm_scan(ctx, idx);
330 } break;
331 case GGML_OP_RWKV_WKV6:
332 case GGML_OP_RWKV_WKV7:
333 {
334 n_fuse = ggml_metal_op_rwkv(ctx, idx);
335 } break;
336 case GGML_OP_SOLVE_TRI:
337 {
338 n_fuse = ggml_metal_op_solve_tri(ctx, idx);
339 } break;
340 case GGML_OP_MUL_MAT:
341 {
342 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
343 } break;
344 case GGML_OP_MUL_MAT_ID:
345 {
346 n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);
347 } break;
348 case GGML_OP_GET_ROWS:
349 {
350 n_fuse = ggml_metal_op_get_rows(ctx, idx);
351 } break;
352 case GGML_OP_SET_ROWS:
353 {
354 n_fuse = ggml_metal_op_set_rows(ctx, idx);
355 } break;
356 case GGML_OP_DIAG:
357 {
358 n_fuse = ggml_metal_op_diag(ctx, idx);
359 } break;
360 case GGML_OP_L2_NORM:
361 {
362 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
363 } break;
364 case GGML_OP_GROUP_NORM:
365 {
366 n_fuse = ggml_metal_op_group_norm(ctx, idx);
367 } break;
368 case GGML_OP_NORM:
369 case GGML_OP_RMS_NORM:
370 {
371 n_fuse = ggml_metal_op_norm(ctx, idx);
372 } break;
373 case GGML_OP_ROPE:
374 {
375 n_fuse = ggml_metal_op_rope(ctx, idx);
376 } break;
377 case GGML_OP_IM2COL:
378 {
379 n_fuse = ggml_metal_op_im2col(ctx, idx);
380 } break;
381 case GGML_OP_CONV_2D:
382 {
383 n_fuse = ggml_metal_op_conv_2d(ctx, idx);
384 } break;
385 case GGML_OP_CONV_TRANSPOSE_1D:
386 {
387 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
388 } break;
389 case GGML_OP_CONV_TRANSPOSE_2D:
390 {
391 n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
392 } break;
393 case GGML_OP_UPSCALE:
394 {
395 n_fuse = ggml_metal_op_upscale(ctx, idx);
396 } break;
397 case GGML_OP_PAD:
398 {
399 n_fuse = ggml_metal_op_pad(ctx, idx);
400 } break;
401 case GGML_OP_PAD_REFLECT_1D:
402 {
403 n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
404 } break;
405 case GGML_OP_ARANGE:
406 {
407 n_fuse = ggml_metal_op_arange(ctx, idx);
408 } break;
409 case GGML_OP_TIMESTEP_EMBEDDING:
410 {
411 n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);
412 } break;
413 case GGML_OP_ARGSORT:
414 {
415 n_fuse = ggml_metal_op_argsort(ctx, idx);
416 } break;
417 case GGML_OP_TOP_K:
418 {
419 n_fuse = ggml_metal_op_top_k(ctx, idx);
420 } break;
421 case GGML_OP_TRI:
422 {
423 n_fuse = ggml_metal_op_tri(ctx, idx);
424 } break;
425 case GGML_OP_FLASH_ATTN_EXT:
426 {
427 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
428 } break;
429 case GGML_OP_DUP:
430 case GGML_OP_CPY:
431 case GGML_OP_CONT:
432 {
433 n_fuse = ggml_metal_op_cpy(ctx, idx);
434 } break;
435 case GGML_OP_POOL_1D:
436 {
437 n_fuse = ggml_metal_op_pool_1d(ctx, idx);
438 } break;
439 case GGML_OP_POOL_2D:
440 {
441 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
442 } break;
443 case GGML_OP_ARGMAX:
444 {
445 n_fuse = ggml_metal_op_argmax(ctx, idx);
446 } break;
447 case GGML_OP_OPT_STEP_ADAMW:
448 {
449 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
450 } break;
451 case GGML_OP_OPT_STEP_SGD:
452 {
453 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
454 } break;
455 case GGML_OP_COUNT_EQUAL:
456 {
457 n_fuse = ggml_metal_op_count_equal(ctx, idx);
458 } break;
459 default:
460 {
461 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
462 GGML_ABORT("fatal error");
463 }
464 }
465
466 if (ctx->debug_graph > 0) {
467 if (n_fuse > 1) {
468 GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
469 }
470 }
471
472 // update the mem ranges in the encoding context
473 for (int i = 0; i < n_fuse; ++i) {
474 if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
475 ggml_metal_op_concurrency_reset(ctx);
476 }
477 }
478
479 return n_fuse;
480}
481
482int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
483 if (ctx->use_capture) {
484 ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
485 }
486
487 int res = ggml_metal_op_encode_impl(ctx, idx);
488 if (idx + res > ctx->n_nodes()) {
489 GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
490 "https://github.com/ggml-org/llama.cpp/pull/14849");
491 }
492
493 if (ctx->use_capture) {
494 ggml_metal_encoder_debug_group_pop(ctx->enc);
495 }
496
497 return res;
498}
499
500int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
501 ggml_tensor * op = ctx->node(idx);
502
503 ggml_metal_library_t lib = ctx->lib;
504 ggml_metal_encoder_t enc = ctx->enc;
505
506 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
507 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
508 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
509 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
510 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
511 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
512
513 const int32_t dim = ((const int32_t *) op->op_params)[0];
514
515 ggml_metal_kargs_concat args = {
516 /*.ne00 =*/ ne00,
517 /*.ne01 =*/ ne01,
518 /*.ne02 =*/ ne02,
519 /*.ne03 =*/ ne03,
520 /*.nb00 =*/ nb00,
521 /*.nb01 =*/ nb01,
522 /*.nb02 =*/ nb02,
523 /*.nb03 =*/ nb03,
524 /*.ne10 =*/ ne10,
525 /*.ne11 =*/ ne11,
526 /*.ne12 =*/ ne12,
527 /*.ne13 =*/ ne13,
528 /*.nb10 =*/ nb10,
529 /*.nb11 =*/ nb11,
530 /*.nb12 =*/ nb12,
531 /*.nb13 =*/ nb13,
532 /*.ne0 =*/ ne0,
533 /*.ne1 =*/ ne1,
534 /*.ne2 =*/ ne2,
535 /*.ne3 =*/ ne3,
536 /*.nb0 =*/ nb0,
537 /*.nb1 =*/ nb1,
538 /*.nb2 =*/ nb2,
539 /*.nb3 =*/ nb3,
540 /*.dim =*/ dim,
541 };
542
543 auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
544
545 ggml_metal_encoder_set_pipeline(enc, pipeline);
546 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
547 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
548 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
549 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
550
551 const int nth = std::min(1024, ne0);
552
553 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
554
555 return 1;
556}
557
558int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
559 ggml_tensor * op = ctx->node(idx);
560
561 ggml_metal_library_t lib = ctx->lib;
562 ggml_metal_encoder_t enc = ctx->enc;
563
564 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
565 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
566 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
567 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
568
569 auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
570
571 ggml_metal_kargs_repeat args = {
572 /*.ne00 =*/ ne00,
573 /*.ne01 =*/ ne01,
574 /*.ne02 =*/ ne02,
575 /*.ne03 =*/ ne03,
576 /*.nb00 =*/ nb00,
577 /*.nb01 =*/ nb01,
578 /*.nb02 =*/ nb02,
579 /*.nb03 =*/ nb03,
580 /*.ne0 =*/ ne0,
581 /*.ne1 =*/ ne1,
582 /*.ne2 =*/ ne2,
583 /*.ne3 =*/ ne3,
584 /*.nb0 =*/ nb0,
585 /*.nb1 =*/ nb1,
586 /*.nb2 =*/ nb2,
587 /*.nb3 =*/ nb3,
588 };
589
590 ggml_metal_encoder_set_pipeline(enc, pipeline);
591 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
592 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
593 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
594
595 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
596
597 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
598
599 return 1;
600}
601
602int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
603 ggml_tensor * op = ctx->node(idx);
604
605 ggml_metal_library_t lib = ctx->lib;
606 ggml_metal_encoder_t enc = ctx->enc;
607
608 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
609 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
610 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
611 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
612 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
613 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
614
615 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
616 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
617 GGML_ASSERT(op->type == GGML_TYPE_F32);
618
619 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
620 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
621
622 const size_t pnb1 = ((const int32_t *) op->op_params)[0];
623 const size_t pnb2 = ((const int32_t *) op->op_params)[1];
624 const size_t pnb3 = ((const int32_t *) op->op_params)[2];
625 const size_t offs = ((const int32_t *) op->op_params)[3];
626
627 const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
628
629 if (!inplace) {
        // run a separate kernel to copy src -> dst
631 // not sure how to avoid this
632 // TODO: make a simpler cpy_bytes kernel
633
634 //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
635 auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
636
637 ggml_metal_kargs_cpy args = {
638 /*.nk0 =*/ ne00,
639 /*.ne00 =*/ ne00,
640 /*.ne01 =*/ ne01,
641 /*.ne02 =*/ ne02,
642 /*.ne03 =*/ ne03,
643 /*.nb00 =*/ nb00,
644 /*.nb01 =*/ nb01,
645 /*.nb02 =*/ nb02,
646 /*.nb03 =*/ nb03,
647 /*.ne0 =*/ ne0,
648 /*.ne1 =*/ ne1,
649 /*.ne2 =*/ ne2,
650 /*.ne3 =*/ ne3,
651 /*.nb0 =*/ nb0,
652 /*.nb1 =*/ nb1,
653 /*.nb2 =*/ nb2,
654 /*.nb3 =*/ nb3,
655 };
656
657 ggml_metal_encoder_set_pipeline(enc, pipeline);
658 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
659 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
660 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
661
662 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
663
664 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
665
666 ggml_metal_op_concurrency_reset(ctx);
667 }
668
669 ggml_metal_kargs_bin args = {
670 /*.ne00 =*/ ne00,
671 /*.ne01 =*/ ne01,
672 /*.ne02 =*/ ne02,
673 /*.ne03 =*/ ne03,
674 /*.nb00 =*/ nb00,
675 /*.nb01 =*/ pnb1,
676 /*.nb02 =*/ pnb2,
677 /*.nb03 =*/ pnb3,
678 /*.ne10 =*/ ne10,
679 /*.ne11 =*/ ne11,
680 /*.ne12 =*/ ne12,
681 /*.ne13 =*/ ne13,
682 /*.nb10 =*/ nb10,
683 /*.nb11 =*/ nb11,
684 /*.nb12 =*/ nb12,
685 /*.nb13 =*/ nb13,
686 /*.ne0 =*/ ne0,
687 /*.ne1 =*/ ne1,
688 /*.ne2 =*/ ne2,
689 /*.ne3 =*/ ne3,
690 /*.nb0 =*/ nb0,
691 /*.nb1 =*/ pnb1,
692 /*.nb2 =*/ pnb2,
693 /*.nb3 =*/ pnb3,
694 /*.offs =*/ offs,
695 /*.o1 =*/ { 0 },
696 };
697
698 auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
699
700 ggml_metal_encoder_set_pipeline(enc, pipeline);
701 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
702 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
703 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
704 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
705
706 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
707
708 ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
709
710 return 1;
711}
712
713int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
714 ggml_tensor * op = ctx->node(idx);
715
716 ggml_metal_library_t lib = ctx->lib;
717 ggml_metal_encoder_t enc = ctx->enc;
718
719 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
720 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
721 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
722 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
723
724 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
725
726 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
727 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
728
729 ggml_metal_kargs_unary args = {
730 /*.ne00 =*/ ne00,
731 /*.ne01 =*/ ne01,
732 /*.ne02 =*/ ne02,
733 /*.ne03 =*/ ne03,
734 /*.nb00 =*/ nb00,
735 /*.nb01 =*/ nb01,
736 /*.nb02 =*/ nb02,
737 /*.nb03 =*/ nb03,
738 /*.ne0 =*/ ne0,
739 /*.ne1 =*/ ne1,
740 /*.ne2 =*/ ne2,
741 /*.ne3 =*/ ne3,
742 /*.nb0 =*/ nb0,
743 /*.nb1 =*/ nb1,
744 /*.nb2 =*/ nb2,
745 /*.nb3 =*/ nb3,
746 /*.slope =*/ 0.0,
747 /*.scale =*/ 0.0,
748 /*.bias =*/ 0.0,
749 /*.val =*/ 0.0,
750 /*.min =*/ 0.0,
751 /*.max =*/ 0.0,
752 };
753
754 if (op->op == GGML_OP_LEAKY_RELU) {
755 args.slope = ggml_get_op_params_f32(op, 0);
756 }
757
758 if (op->op == GGML_OP_SCALE) {
759 args.scale = ggml_get_op_params_f32(op, 0);
760 args.bias = ggml_get_op_params_f32(op, 1);
761 }
762
763 if (op->op == GGML_OP_FILL) {
764 args.val = ggml_get_op_params_f32(op, 0);
765 }
766
767 if (op->op == GGML_OP_CLAMP) {
768 args.min = ggml_get_op_params_f32(op, 0);
769 args.max = ggml_get_op_params_f32(op, 1);
770 }
771
772 auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
773
774 if (pipeline.c4) {
775 args.ne00 = ne00/4;
776 args.ne0 = ne0/4;
777 }
778
779 ggml_metal_encoder_set_pipeline(enc, pipeline);
780 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
781 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
782 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
783
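    // the .cnt pipelines presumably handle the fully-contiguous case: one single-thread
    // threadgroup per element (or per 4 elements for the .c4 variants);
    // otherwise dispatch per-row threadgroups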
784 if (pipeline.cnt) {
785 const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
786
787 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
788 } else {
789 const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
790
791 const int nth = MIN(args.ne00, nth_max);
792
793 const int nk0 = (args.ne00 + nth - 1)/nth;
794
795 ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
796 }
797
798 return 1;
799}
800
801int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
802 ggml_tensor * op = ctx->node(idx);
803
804 ggml_metal_library_t lib = ctx->lib;
805 ggml_metal_encoder_t enc = ctx->enc;
806
807 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
808 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
809 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
810 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
811 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
812 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
813
814 if (op->src[1]) {
815 GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
816 }
817
818 auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
819
820 const int32_t swp = ggml_get_op_params_i32(op, 1);
821 const float alpha = ggml_get_op_params_f32(op, 2);
822 const float limit = ggml_get_op_params_f32(op, 3);
823
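    // when src1 is absent, src0 rows contain both the activation and the gate halves;
    // i00/i10 select which half is which, depending on the swap flag in op_params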
824 const int32_t i00 = swp ? ne0 : 0;
825 const int32_t i10 = swp ? 0 : ne0;
826
827 ggml_metal_kargs_glu args = {
828 /*.ne00 =*/ ne00,
829 /*.nb01 =*/ nb01,
830 /*.ne10 =*/ op->src[1] ? ne10 : ne00,
831 /*.nb11 =*/ op->src[1] ? nb11 : nb01,
832 /*.ne0 =*/ ne0,
833 /*.nb1 =*/ nb1,
834 /*.i00 =*/ op->src[1] ? 0 : i00,
835 /*.i10 =*/ op->src[1] ? 0 : i10,
836 /*.alpha=*/ alpha,
837 /*.limit=*/ limit
838 };
839
840 const int64_t nrows = ggml_nrows(op->src[0]);
841
842 const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
843
844 ggml_metal_encoder_set_pipeline(enc, pipeline);
845 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
846 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
847 if (op->src[1]) {
848 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
849 } else {
850 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2);
851 }
852 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
853
854 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
855
856 return 1;
857}
858
859int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
860 ggml_tensor * op = ctx->node(idx);
861
862 ggml_metal_library_t lib = ctx->lib;
863 ggml_metal_encoder_t enc = ctx->enc;
864
865 const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
866
867 ggml_metal_kargs_sum args = {
868 /*.np =*/ n,
869 };
870
871 auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
872
873 int nth = 32; // SIMD width
874
875 while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
876 nth *= 2;
877 }
878
879 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
880 nth = std::min(nth, (int) n);
881
882 const int nsg = (nth + 31) / 32;
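    // one partial sum per simdgroup is kept in threadgroup memory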
883
884 ggml_metal_encoder_set_pipeline(enc, pipeline);
885 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
886 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
887 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
888
889 ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
890
891 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
892
893 return 1;
894}
895
896int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
897 ggml_tensor * op = ctx->node(idx);
898
899 ggml_metal_library_t lib = ctx->lib;
900 ggml_metal_encoder_t enc = ctx->enc;
901
902 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
903 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
904 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
905 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
906
907 ggml_metal_kargs_sum_rows args = {
908 /*.ne00 =*/ ne00,
909 /*.ne01 =*/ ne01,
910 /*.ne02 =*/ ne02,
911 /*.ne03 =*/ ne03,
912 /*.nb00 =*/ nb00,
913 /*.nb01 =*/ nb01,
914 /*.nb02 =*/ nb02,
915 /*.nb03 =*/ nb03,
916 /*.ne0 =*/ ne0,
917 /*.ne1 =*/ ne1,
918 /*.ne2 =*/ ne2,
919 /*.ne3 =*/ ne3,
920 /*.nb0 =*/ nb0,
921 /*.nb1 =*/ nb1,
922 /*.nb2 =*/ nb2,
923 /*.nb3 =*/ nb3,
924 };
925
926 auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
927
928 int nth = 32; // SIMD width
929
930 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
931 nth *= 2;
932 }
933
934 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
935 nth = std::min(nth, ne00);
936
937 const size_t smem = pipeline.smem;
938
939 ggml_metal_encoder_set_pipeline(enc, pipeline);
940 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
941 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
942 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
943
944 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
945
946 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
947
948 return 1;
949}
950
951int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
952 ggml_tensor * op = ctx->node(idx);
953
954 ggml_metal_library_t lib = ctx->lib;
955 ggml_metal_encoder_t enc = ctx->enc;
956
957 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
958
959 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
960 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
961 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
962 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
963
964 auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
965
966 int nth = 1;
967 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
968 nth *= 2;
969 }
970
971 GGML_ASSERT(ne00 <= nth*nth);
972
973 const int64_t net0 = (ne00 + nth - 1) / nth;
974 const int64_t net1 = ne01;
975 const int64_t net2 = ne02;
976 const int64_t net3 = ne03;
977
978 const uint64_t nbt0 = sizeof(float);
979 const uint64_t nbt1 = net0*nbt0;
980 const uint64_t nbt2 = net1*nbt1;
981 const uint64_t nbt3 = net2*nbt2;
982
983 const size_t smem = GGML_PAD(32*sizeof(float), 16);
984
985 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
986 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
987
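    // the per-block partial results are written to a temporary region placed right after the dst
    // tensor data (extra workspace for this op is assumed to be available in the same buffer)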
988 ggml_metal_buffer_id bid_tmp = bid_dst;
989 bid_tmp.offs += ggml_nbytes(op);
990
991 {
992 ggml_metal_kargs_cumsum_blk args = {
993 /*.ne00 =*/ ne00,
994 /*.ne01 =*/ ne01,
995 /*.ne02 =*/ ne02,
996 /*.ne03 =*/ ne03,
997 /*.nb00 =*/ nb00,
998 /*.nb01 =*/ nb01,
999 /*.nb02 =*/ nb02,
1000 /*.nb03 =*/ nb03,
1001 /*.net0 =*/ net0,
1002 /*.net1 =*/ net1,
1003 /*.net2 =*/ net2,
1004 /*.net3 =*/ net3,
1005 /*.nbt0 =*/ nbt0,
1006 /*.nbt1 =*/ nbt1,
1007 /*.nbt2 =*/ nbt2,
1008 /*.nbt3 =*/ nbt3,
1009 /*.outb =*/ ne00 > nth,
1010 };
1011
1012 ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1013 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1014 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1015 ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1016 ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1017
1018 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1019
1020 ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1021 }
1022
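    // rows longer than one threadgroup need a second pass: scan the per-block totals in-place,
    // then add the scanned totals back to each block of the destination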
1023 if (ne00 > nth) {
1024 ggml_metal_op_concurrency_reset(ctx);
1025
1026 {
1027 ggml_metal_kargs_cumsum_blk args = {
1028 /*.ne00 =*/ net0,
1029 /*.ne01 =*/ net1,
1030 /*.ne02 =*/ net2,
1031 /*.ne03 =*/ net3,
1032 /*.nb00 =*/ nbt0,
1033 /*.nb01 =*/ nbt1,
1034 /*.nb02 =*/ nbt2,
1035 /*.nb03 =*/ nbt3,
1036 /*.net0 =*/ net0,
1037 /*.net1 =*/ net1,
1038 /*.net2 =*/ net2,
1039 /*.net3 =*/ net3,
1040 /*.nbt0 =*/ nbt0,
1041 /*.nbt1 =*/ nbt1,
1042 /*.nbt2 =*/ nbt2,
1043 /*.nbt3 =*/ nbt3,
1044 /*.outb =*/ false,
1045 };
1046
1047 ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1048 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1049 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1050 ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1051 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1052
1053 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1054
1055 ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1056 }
1057
1058 ggml_metal_op_concurrency_reset(ctx);
1059
1060 {
1061 auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1062
1063 ggml_metal_kargs_cumsum_add args = {
1064 /*.ne00 =*/ ne00,
1065 /*.ne01 =*/ ne01,
1066 /*.ne02 =*/ ne02,
1067 /*.ne03 =*/ ne03,
1068 /*.nb00 =*/ nb00,
1069 /*.nb01 =*/ nb01,
1070 /*.nb02 =*/ nb02,
1071 /*.nb03 =*/ nb03,
1072 /*.net0 =*/ net0,
1073 /*.net1 =*/ net1,
1074 /*.net2 =*/ net2,
1075 /*.net3 =*/ net3,
1076 /*.nbt0 =*/ nbt0,
1077 /*.nbt1 =*/ nbt1,
1078 /*.nbt2 =*/ nbt2,
1079 /*.nbt3 =*/ nbt3,
1080 };
1081
1082 ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1083 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1084 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1085 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1086
1087 ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1088 }
1089 }
1090
1091 return 1;
1092}
1093
1094int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
1095 ggml_tensor * op = ctx->node(idx);
1096
1097 ggml_metal_library_t lib = ctx->lib;
1098 ggml_metal_encoder_t enc = ctx->enc;
1099
1100 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1101 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1102 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1103 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1104 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1105 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1106
1107 auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1108
1109 ggml_metal_kargs_get_rows args = {
1110 /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1111 /*.ne00 =*/ ne00,
1112 /*.nb01 =*/ nb01,
1113 /*.nb02 =*/ nb02,
1114 /*.nb03 =*/ nb03,
1115 /*.ne10 =*/ ne10,
1116 /*.nb10 =*/ nb10,
1117 /*.nb11 =*/ nb11,
1118 /*.nb12 =*/ nb12,
1119 /*.nb1 =*/ nb1,
1120 /*.nb2 =*/ nb2,
1121 /*.nb3 =*/ nb3,
1122 };
1123
1124 const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1125
1126 const int nw0 = (args.ne00t + nth - 1)/nth;
1127
1128 ggml_metal_encoder_set_pipeline(enc, pipeline);
1129 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1130 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1131 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1132 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1133
1134 ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
1135
1136 return 1;
1137}
1138
1139int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1140 ggml_tensor * op = ctx->node(idx);
1141
1142 ggml_metal_library_t lib = ctx->lib;
1143 ggml_metal_encoder_t enc = ctx->enc;
1144
1145 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1146 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1147 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1148 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1149 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1150 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1151
1152 auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1153
1154 const int32_t nk0 = ne0/ggml_blck_size(op->type);
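    // number of blocks per destination row (for non-quantized dst types the block size is 1)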
1155
1156 int nth = 32; // SIMD width
1157
1158 while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1159 nth *= 2;
1160 }
1161
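    // if a single row needs fewer than nth threads, pack multiple rows into one threadgroup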
1162 int nrptg = 1;
1163 if (nth > nk0) {
1164 nrptg = (nth + nk0 - 1)/nk0;
1165 nth = nk0;
1166
1167 if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1168 nrptg--;
1169 }
1170 }
1171
1172 nth = std::min(nth, nk0);
1173
1174 ggml_metal_kargs_set_rows args = {
1175 /*.nk0 =*/ nk0,
1176 /*.ne01 =*/ ne01,
1177 /*.nb01 =*/ nb01,
1178 /*.nb02 =*/ nb02,
1179 /*.nb03 =*/ nb03,
1180 /*.ne11 =*/ ne11,
1181 /*.ne12 =*/ ne12,
1182 /*.nb10 =*/ nb10,
1183 /*.nb11 =*/ nb11,
1184 /*.nb12 =*/ nb12,
1185 /*.nb1 =*/ nb1,
1186 /*.nb2 =*/ nb2,
1187 /*.nb3 =*/ nb3,
1188 };
1189
1190 ggml_metal_encoder_set_pipeline(enc, pipeline);
1191 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1192 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1193 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1194 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1195
1196 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1197
1198 return 1;
1199}
1200
1201int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1202 ggml_tensor * op = ctx->node(idx);
1203
1204 ggml_metal_library_t lib = ctx->lib;
1205 ggml_metal_encoder_t enc = ctx->enc;
1206
1207 GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
1208 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1209 GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
1210 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1211
1212 ggml_metal_kargs_diag args = {
1213 /*.ne00 =*/ne00,
1214 /*.ne01 =*/ne01,
1215 /*.ne02 =*/ne02,
1216 /*.ne03 =*/ne03,
1217 /*.nb00 =*/nb00,
1218 /*.nb01 =*/nb01,
1219 /*.nb02 =*/nb02,
1220 /*.nb03 =*/nb03,
1221 /*.ne0 =*/ne0,
1222 /*.ne1 =*/ne1,
1223 /*.ne2 =*/ne2,
1224 /*.ne3 =*/ne3,
1225 /*.nb0 =*/nb0,
1226 /*.nb1 =*/nb1,
1227 /*.nb2 =*/nb2,
1228 /*.nb3 =*/nb3,
1229 };
1230
1231 auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1232
1233 ggml_metal_encoder_set_pipeline(enc, pipeline);
1234 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1235 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1236 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
1237
1238 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1239
1240 return 1;
1241}
1242
1243int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1244 ggml_tensor * op = ctx->node(idx);
1245
1246 ggml_metal_library_t lib = ctx->lib;
1247 ggml_metal_encoder_t enc = ctx->enc;
1248
1249 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1250 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1251 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1252 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1253 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1254 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1255 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1256 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1257
1258 float scale;
1259 float max_bias;
1260
1261 memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
1262 memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
1263
1264 const uint32_t n_head = op->src[0]->ne[2];
1265 const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1266
1267 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1268 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1269
1270 // softmax
1271
1272 ggml_metal_kargs_soft_max args = {
1273 /*.ne00 =*/ ne00,
1274 /*.ne01 =*/ ne01,
1275 /*.ne02 =*/ ne02,
1276 /*.nb01 =*/ nb01,
1277 /*.nb02 =*/ nb02,
1278 /*.nb03 =*/ nb03,
1279 /*.ne11 =*/ ne11,
1280 /*.ne12 =*/ ne12,
1281 /*.ne13 =*/ ne13,
1282 /*.nb11 =*/ nb11,
1283 /*.nb12 =*/ nb12,
1284 /*.nb13 =*/ nb13,
1285 /*.nb1 =*/ nb1,
1286 /*.nb2 =*/ nb2,
1287 /*.nb3 =*/ nb3,
1288 /*.scale =*/ scale,
1289 /*.max_bias =*/ max_bias,
1290 /*.m0 =*/ m0,
1291 /*.m1 =*/ m1,
1292 /*.n_head_log2 =*/ n_head_log2,
1293 };
1294
1295 auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1296
1297 int nth = 32; // SIMD width
1298
1299 if (ne00%4 == 0) {
1300 while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1301 nth *= 2;
1302 }
1303 } else {
1304 while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1305 nth *= 2;
1306 }
1307 }
1308
1309 const size_t smem = pipeline.smem;
1310
1311 ggml_metal_encoder_set_pipeline(enc, pipeline);
1312 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1313 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1314 if (op->src[1]) {
1315 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1316 } else {
1317 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
1318 }
1319 if (op->src[2]) {
1320 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
1321 } else {
1322 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
1323 }
1324 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
1325
1326 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1327
1328 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1329
1330 return 1;
1331}
1332
1333int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1334 ggml_tensor * op = ctx->node(idx);
1335
1336 ggml_metal_library_t lib = ctx->lib;
1337 ggml_metal_encoder_t enc = ctx->enc;
1338
1339 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1340 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1341 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1342 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1343 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1344 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1345
1346 ggml_metal_kargs_ssm_conv args = {
1347 /*.ne00 =*/ ne00,
1348 /*.ne01 =*/ ne01,
1349 /*.ne02 =*/ ne02,
1350 /*.nb00 =*/ nb00,
1351 /*.nb01 =*/ nb01,
1352 /*.nb02 =*/ nb02,
1353 /*.ne10 =*/ ne10,
1354 /*.ne11 =*/ ne11,
1355 /*.nb10 =*/ nb10,
1356 /*.nb11 =*/ nb11,
1357 /*.ne0 =*/ ne0,
1358 /*.ne1 =*/ ne1,
1359 /*.ne2 =*/ ne2,
1360 /*.nb0 =*/ nb0,
1361 /*.nb1 =*/ nb1,
1362 /*.nb2 =*/ nb2,
1363 };
1364
1365 // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1366 const bool use_batched = (ne1 > 1);
1367
1368 if (use_batched) {
        // Pick a power-of-2 batch size based on ne1 (capped at 256)
1370 int BATCH_SIZE;
1371 if (ne1 > 128) BATCH_SIZE = 256;
1372 else if (ne1 > 64 ) BATCH_SIZE = 128;
1373 else if (ne1 > 32 ) BATCH_SIZE = 64;
1374 else if (ne1 > 16 ) BATCH_SIZE = 32;
1375 else if (ne1 > 8 ) BATCH_SIZE = 16;
1376 else if (ne1 > 4 ) BATCH_SIZE = 8;
1377 else BATCH_SIZE = 2;
1378
1379 auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1380
1381 ggml_metal_encoder_set_pipeline(enc, pipeline);
1382 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1383 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1384 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1385 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1386
1387 // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1388 // Each threadgroup has BATCH_SIZE threads, each handling one token
1389 const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1390 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1391 } else {
1392 auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1393
1394 ggml_metal_encoder_set_pipeline(enc, pipeline);
1395 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1396 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1397 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1398 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1399
1400 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1401 }
1402
1403 return 1;
1404}
1405
1406int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1407 ggml_tensor * op = ctx->node(idx);
1408
1409 ggml_metal_library_t lib = ctx->lib;
1410 ggml_metal_encoder_t enc = ctx->enc;
1411
1412 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1413 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1414 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1415 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1416 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1417 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1418 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1419 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1420 GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
1421 GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
1422 GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
1423 GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
1424 GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1425 GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1426 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1427 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1428
1429 const ggml_tensor * src3 = op->src[3];
1430 const ggml_tensor * src4 = op->src[4];
1431 const ggml_tensor * src5 = op->src[5];
1432 const ggml_tensor * src6 = op->src[6];
1433
1434 GGML_ASSERT(src3);
1435 GGML_ASSERT(src4);
1436 GGML_ASSERT(src5);
1437 GGML_ASSERT(src6);
1438
1439 const int64_t d_state = ne00;
1440 const int64_t d_inner = ne01;
1441 const int64_t n_head = ne02;
1442 const int64_t n_group = ne41;
1443 const int64_t n_seq_tokens = ne12;
1444 const int64_t n_seqs = ne13;
1445
1446 ggml_metal_kargs_ssm_scan args = {
1447 /*.d_state =*/ d_state,
1448 /*.d_inner =*/ d_inner,
1449 /*.n_head =*/ n_head,
1450 /*.n_group =*/ n_group,
1451 /*.n_seq_tokens =*/ n_seq_tokens,
1452 /*.n_seqs =*/ n_seqs,
1453 /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1454 /*.nb00 =*/ nb00,
1455 /*.nb01 =*/ nb01,
1456 /*.nb02 =*/ nb02,
1457 /*.nb03 =*/ nb03,
1458 /*.nb10 =*/ nb10,
1459 /*.nb11 =*/ nb11,
1460 /*.nb12 =*/ nb12,
1461 /*.ns12 =*/ nb12/nb10,
1462 /*.nb13 =*/ nb13,
1463 /*.nb20 =*/ nb20,
1464 /*.nb21 =*/ nb21,
1465 /*.ns21 =*/ nb21/nb20,
1466 /*.nb22 =*/ nb22,
1467 /*.ne30 =*/ ne30,
1468 /*.nb31 =*/ nb31,
1469 /*.nb41 =*/ nb41,
1470 /*.nb42 =*/ nb42,
1471 /*.ns42 =*/ nb42/nb40,
1472 /*.nb43 =*/ nb43,
1473 /*.nb51 =*/ nb51,
1474 /*.nb52 =*/ nb52,
1475 /*.ns52 =*/ nb52/nb50,
1476 /*.nb53 =*/ nb53,
1477 /*.nb0 =*/ nb0,
1478 };
1479
1480 auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1481
1482 GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1483
1484 const size_t smem = pipeline.smem;
1485
1486 ggml_metal_encoder_set_pipeline(enc, pipeline);
1487 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1488 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1489 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1490 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
1491 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4);
1492 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5);
1493 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6);
1494 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
1495 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
1496
1497 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1498
1499 ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1500
1501 return 1;
1502}
1503
1504int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1505 ggml_tensor * op = ctx->node(idx);
1506
1507 ggml_metal_library_t lib = ctx->lib;
1508 ggml_metal_encoder_t enc = ctx->enc;
1509
1510 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1511 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1512 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1513 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1514
1515 const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1516 const int64_t T = op->src[0]->ne[2];
1517 const int64_t C = op->ne[0];
1518 const int64_t H = op->src[0]->ne[1];
1519
1520 auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1521
1522 int ida = 0;
1523
1524 ggml_metal_encoder_set_pipeline(enc, pipeline);
1525 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1526 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1527 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1528 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1529 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1530 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1531 if (op->op == GGML_OP_RWKV_WKV7) {
1532 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1533 }
1534 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1535 ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1536 ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1537 ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1538 ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1539
1540 ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1541
1542 return 1;
1543}
1544
1545int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1546 ggml_tensor * op = ctx->node(idx);
1547
1548 ggml_metal_library_t lib = ctx->lib;
1549 ggml_metal_encoder_t enc = ctx->enc;
1550
1551 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1552 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1553 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1554 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1555 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1556 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1557
1558 ggml_metal_kargs_solve_tri args = {
1559 /*.ne00 =*/ ne00,
1560 /*.ne01 =*/ ne01,
1561 /*.ne02 =*/ ne02,
1562 /*.ne03 =*/ ne03,
1563 /*.nb00 =*/ nb00,
1564 /*.nb01 =*/ nb01,
1565 /*.nb02 =*/ nb02,
1566 /*.nb03 =*/ nb03,
1567 /*.ne10 =*/ ne10,
1568 /*.ne11 =*/ ne11,
1569 /*.ne12 =*/ ne12,
1570 /*.ne13 =*/ ne13,
1571 /*.nb10 =*/ nb10,
1572 /*.nb11 =*/ nb11,
1573 /*.nb12 =*/ nb12,
1574 /*.nb13 =*/ nb13,
1575 /*.ne0 =*/ ne0,
1576 /*.ne1 =*/ ne1,
1577 /*.ne2 =*/ ne2,
1578 /*.ne3 =*/ ne3,
1579 /*.nb0 =*/ nb0,
1580 /*.nb1 =*/ nb1,
1581 /*.nb2 =*/ nb2,
1582 /*.nb3 =*/ nb3,
1583 };
1584
1585 auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1586
1587 ggml_metal_encoder_set_pipeline(enc, pipeline);
1588 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1589 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1590 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1591 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1592
1593 const int nsg = pipeline.nsg;
1594
1595 ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1596
1597 ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1598
1599 return 1;
1600}
1601
1602int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1603 ggml_tensor * op = ctx->node(idx);
1604
1605 ggml_metal_library_t lib = ctx->lib;
1606 ggml_metal_encoder_t enc = ctx->enc;
1607
1608 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1609 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1610 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1611 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1612
1613 auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1614
1615 GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1616
1617 int64_t nk0 = ne00;
1618 if (ggml_is_quantized(op->src[0]->type)) {
1619 nk0 = ne00/16;
1620 } else if (ggml_is_quantized(op->type)) {
1621 nk0 = ne00/ggml_blck_size(op->type);
1622 }
1623
1624 int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1625
1626 // when rows are small, we can batch them together in a single threadgroup
1627 int nrptg = 1;
1628
1629 // TODO: relax this constraint in the future
1630 if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1631 if (nth > nk0) {
1632 nrptg = (nth + nk0 - 1)/nk0;
1633 nth = nk0;
1634
1635 if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1636 nrptg--;
1637 }
1638 }
1639 }
1640
1641 nth = std::min<int>(nth, nk0);
1642
1643 ggml_metal_kargs_cpy args = {
1644 /*.nk0 =*/ nk0,
1645 /*.ne00 =*/ ne00,
1646 /*.ne01 =*/ ne01,
1647 /*.ne02 =*/ ne02,
1648 /*.ne03 =*/ ne03,
1649 /*.nb00 =*/ nb00,
1650 /*.nb01 =*/ nb01,
1651 /*.nb02 =*/ nb02,
1652 /*.nb03 =*/ nb03,
1653 /*.ne0 =*/ ne0,
1654 /*.ne1 =*/ ne1,
1655 /*.ne2 =*/ ne2,
1656 /*.ne3 =*/ ne3,
1657 /*.nb0 =*/ nb0,
1658 /*.nb1 =*/ nb1,
1659 /*.nb2 =*/ nb2,
1660 /*.nb3 =*/ nb3,
1661 };
1662
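    // number of threadgroups needed along the row when a single row spans more than one threadgroup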
1663 const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1664
1665 ggml_metal_encoder_set_pipeline(enc, pipeline);
1666 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1667 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1668 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1669
1670 ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1671
1672 return 1;
1673}
1674
1675int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1676 ggml_tensor * op = ctx->node(idx);
1677
1678 ggml_metal_library_t lib = ctx->lib;
1679 ggml_metal_encoder_t enc = ctx->enc;
1680
1681 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1682 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1683 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1684 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1685
1686 const int32_t * opts = op->op_params;
1687 ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1688
1689 const int32_t k0 = opts[1];
1690 const int32_t s0 = opts[2];
1691 const int32_t p0 = opts[3];
1692
1693 const int64_t IW = op->src[0]->ne[0];
1694 const int64_t OW = op->ne[0];
1695
1696 const int64_t np = ggml_nelements(op);
1697
1698 ggml_metal_kargs_pool_1d args_pool_1d = {
1699 /* .k0 = */ k0,
1700 /* .s0 = */ s0,
1701 /* .p0 = */ p0,
1702 /* .IW = */ IW,
1703 /* .OW = */ OW,
1704 /* .np = */ np
1705 };
1706
1707 auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1708
1709 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1710 const int ntg = (np + nth - 1) / nth;
1711
1712 ggml_metal_encoder_set_pipeline(enc, pipeline);
1713 ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1714 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1715 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1716
1717 ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1718
1719 return 1;
1720}
1721
1722
1723int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1724 ggml_tensor * op = ctx->node(idx);
1725
1726 ggml_metal_library_t lib = ctx->lib;
1727 ggml_metal_encoder_t enc = ctx->enc;
1728
1729 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1730 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1731 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1732 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1733
1734 const int32_t * opts = op->op_params;
1735 ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1736
1737 const int32_t k0 = opts[1];
1738 const int32_t k1 = opts[2];
1739 const int32_t s0 = opts[3];
1740 const int32_t s1 = opts[4];
1741 const int32_t p0 = opts[5];
1742 const int32_t p1 = opts[6];
1743
1744 const int64_t IH = op->src[0]->ne[1];
1745 const int64_t IW = op->src[0]->ne[0];
1746
1747 const int64_t N = op->ne[3];
1748 const int64_t OC = op->ne[2];
1749 const int64_t OH = op->ne[1];
1750 const int64_t OW = op->ne[0];
1751
1752 const int64_t np = N * OC * OH * OW;
1753
1754 ggml_metal_kargs_pool_2d args_pool_2d = {
1755 /* .k0 = */ k0,
1756 /* .k1 = */ k1,
1757 /* .s0 = */ s0,
1758 /* .s1 = */ s1,
1759 /* .p0 = */ p0,
1760 /* .p1 = */ p1,
1761 /* .IH = */ IH,
1762 /* .IW = */ IW,
1763 /* .OH = */ OH,
1764 /* .OW = */ OW,
1765 /* .np = */ np
1766 };
1767
1768 auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1769
1770 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1771 const int ntg = (np + nth - 1) / nth;
1772
1773 ggml_metal_encoder_set_pipeline(enc, pipeline);
1774 ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
1775 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1776 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1777
1778 ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1779
1780 return 1;
1781}
1782
1783int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1784 ggml_tensor * op = ctx->node(idx);
1785
1786 ggml_metal_library_t lib = ctx->lib;
1787 ggml_metal_encoder_t enc = ctx->enc;
1788
1789 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
1790
1791 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1792 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1793 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1794 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1795 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1796 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1797
1798 GGML_ASSERT(ne00 == ne10);
1799
1800 GGML_ASSERT(ne12 % ne02 == 0);
1801 GGML_ASSERT(ne13 % ne03 == 0);
1802
1803 const int16_t r2 = ne12/ne02;
1804 const int16_t r3 = ne13/ne03;
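    // broadcast ratios of src1 over src0 along dims 2 and 3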
1805
1806 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1807 // to the matrix-vector kernel
1808 const int ne11_mm_min = 8;
1809
1810 // first try to use small-batch mat-mv kernels
1811 // these should be efficient for BS [2, ~8]
1812 if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&
1813 (
1814 (
1815 (
1816 op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
1817 op->src[0]->type == GGML_TYPE_F16 ||
1818 op->src[0]->type == GGML_TYPE_Q4_0 ||
1819 op->src[0]->type == GGML_TYPE_Q4_1 ||
1820 op->src[0]->type == GGML_TYPE_Q5_0 ||
1821 op->src[0]->type == GGML_TYPE_Q5_1 ||
1822 op->src[0]->type == GGML_TYPE_Q8_0 ||
1823 op->src[0]->type == GGML_TYPE_MXFP4 ||
1824 op->src[0]->type == GGML_TYPE_IQ4_NL ||
1825 false) && (ne11 >= 2 && ne11 <= 8)
1826 ) ||
1827 (
1828 (
1829 op->src[0]->type == GGML_TYPE_Q4_K ||
1830 op->src[0]->type == GGML_TYPE_Q5_K ||
1831 op->src[0]->type == GGML_TYPE_Q6_K ||
1832 false) && (ne11 >= 4 && ne11 <= 8)
1833 )
1834 )
1835 ) {
1836 // TODO: determine the optimal parameters based on grid utilization
1837 // I still don't know why we should not always use the maximum available threads:
1838 //
1839 // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
1840 //
1841        // my current hypothesis is that the work grid does not divide evenly for different nsg
1842        // values, so there can be tail effects when nsg is high. this needs to be confirmed
1843 //
1844 const int nsg = 2; // num simdgroups per threadgroup
1845
1846 // num threads along row per simdgroup
1847 int16_t nxpsg = 0;
1848 if (ne00 % 256 == 0 && ne11 < 3) {
1849 nxpsg = 16;
1850 } else if (ne00 % 128 == 0) {
1851 nxpsg = 8;
1852 } else {
1853 nxpsg = 4;
1854 }
1855
1856 const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
1857 const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1858 int16_t r1ptg = 4; // num src1 rows per threadgroup
1859
1860        // note: not sure how optimal these values are across different hardware; there might be something cleverer
1861 switch (ne11) {
1862 case 2:
1863 r1ptg = 2; break;
1864 case 3:
1865 case 6:
1866 r1ptg = 3; break;
1867 case 4:
1868 case 7:
1869 case 8:
1870 r1ptg = 4; break;
1871 case 5:
1872 r1ptg = 5; break;
1873 default:
1874 GGML_ABORT("unsupported ne11");
1875 };
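        // illustrative example (with nsg = 2 as above): ne00 = 4096, ne11 = 4 -> nxpsg = 8, nypsg = 4, r0ptg = 8, r1ptg = 4,
        // so the dispatch below launches ceil(ne01/8) x 1 x ne12*ne13 threadgroups of 32 x nsg threads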
1876
1877 auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1878
1879 ggml_metal_kargs_mul_mv_ext args = {
1880 /*.ne00 =*/ ne00,
1881 /*.ne01 =*/ ne01,
1882 /*.ne02 =*/ ne02,
1883 /*.nb00 =*/ nb00,
1884 /*.nb01 =*/ nb01,
1885 /*.nb02 =*/ nb02,
1886 /*.nb03 =*/ nb03,
1887 /*.ne10 =*/ ne10,
1888 /*.ne11 =*/ ne11,
1889 /*.ne12 =*/ ne12,
1890 /*.nb10 =*/ nb10,
1891 /*.nb11 =*/ nb11,
1892 /*.nb12 =*/ nb12,
1893 /*.nb13 =*/ nb13,
1894 /*.ne0 =*/ ne0,
1895 /*.ne1 =*/ ne1,
1896 /*.r2 =*/ r2,
1897 /*.r3 =*/ r3,
1898 };
1899
1900 ggml_metal_encoder_set_pipeline(enc, pipeline);
1901 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1902 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1903 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1904 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1905
1906 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
1907 } else if (
1908 !ggml_is_transposed(op->src[0]) &&
1909 !ggml_is_transposed(op->src[1]) &&
1910 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1911 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1912 props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1913 //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1914
1915 // some Metal matrix data types require aligned pointers
1916 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1917 //switch (op->src[0]->type) {
1918 // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1919 // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1920 // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1921 // default: break;
1922 //}
1923
1924 auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
1925
1926 ggml_metal_kargs_mul_mm args = {
1927 /*.ne00 =*/ ne00,
1928 /*.ne02 =*/ ne02,
1929 /*.nb01 =*/ nb01,
1930 /*.nb02 =*/ nb02,
1931 /*.nb03 =*/ nb03,
1932 /*.ne12 =*/ ne12,
1933 /*.nb10 =*/ nb10,
1934 /*.nb11 =*/ nb11,
1935 /*.nb12 =*/ nb12,
1936 /*.nb13 =*/ nb13,
1937 /*.ne0 =*/ ne0,
1938 /*.ne1 =*/ ne1,
1939 /*.r2 =*/ r2,
1940 /*.r3 =*/ r3,
1941 };
1942
1943 ggml_metal_encoder_set_pipeline(enc, pipeline);
1944 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1945 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1946 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1947 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1948
1949 const size_t smem = pipeline.smem;
1950
1951 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
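        // presumably the mul_mm kernel computes one 64 (src0 rows) x 32 (src1 rows) tile of the result per
        // 128-thread threadgroup, hence the ceil-divisions of ne01 by 64 and ne11 by 32 in the dispatch below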
1952 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1953 } else {
1954 auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1955
1956 const int nr0 = pipeline.nr0;
1957 const int nr1 = pipeline.nr1;
1958 const int nsg = pipeline.nsg;
1959
1960 const size_t smem = pipeline.smem;
1961
1962 ggml_metal_kargs_mul_mv args = {
1963 /*.ne00 =*/ ne00,
1964 /*.ne01 =*/ ne01,
1965 /*.ne02 =*/ ne02,
1966 /*.nb00 =*/ nb00,
1967 /*.nb01 =*/ nb01,
1968 /*.nb02 =*/ nb02,
1969 /*.nb03 =*/ nb03,
1970 /*.ne10 =*/ ne10,
1971 /*.ne11 =*/ ne11,
1972 /*.ne12 =*/ ne12,
1973 /*.nb10 =*/ nb10,
1974 /*.nb11 =*/ nb11,
1975 /*.nb12 =*/ nb12,
1976 /*.nb13 =*/ nb13,
1977 /*.ne0 =*/ ne0,
1978 /*.ne1 =*/ ne1,
1979 /*.nr0 =*/ nr0,
1980 /*.r2 =*/ r2,
1981 /*.r3 =*/ r3,
1982 };
1983
1984 ggml_metal_encoder_set_pipeline(enc, pipeline);
1985 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1986 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1987 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1988 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1989
1990 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1991
1992 if (op->src[0]->type == GGML_TYPE_F32 ||
1993 op->src[0]->type == GGML_TYPE_F16 ||
1994 op->src[0]->type == GGML_TYPE_BF16 ||
1995 op->src[0]->type == GGML_TYPE_Q8_0) {
1996 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1997 } else {
1998 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1999 }
2000 }
2001
2002 return 1;
2003}
2004
2005size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {
2006 assert(op->op == GGML_OP_MUL_MAT_ID);
2007
2008 const int64_t ne02 = op->src[0]->ne[2]; // n_expert
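    // scratch for one int32 per expert - presumably the per-expert token count ("tpe" ~ tokens per expert)
    // produced by the mul_mm_id_map0 kernel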
2009
2010 return ggml_type_size(GGML_TYPE_I32)*ne02;
2011}
2012
2013size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
2014 assert(op->op == GGML_OP_MUL_MAT_ID);
2015
2016 const int64_t ne02 = op->src[0]->ne[2]; // n_expert
2017 const int64_t ne21 = op->src[2]->ne[1]; // n_token
2018
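    // sized for the worst case where every token is routed to every expert: ne02*ne21 int32 token ids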
2019 return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
2020}
2021
2022int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
2023 ggml_tensor * op = ctx->node(idx);
2024
2025 ggml_metal_library_t lib = ctx->lib;
2026 ggml_metal_encoder_t enc = ctx->enc;
2027
2028 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
2029
2030 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2031 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2032 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2033 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2034 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2035 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2036 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2037 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2038
2039 // src2 = ids
2040 GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
2041
2042 GGML_ASSERT(!ggml_is_transposed(op->src[0]));
2043 GGML_ASSERT(!ggml_is_transposed(op->src[1]));
2044
2045 GGML_ASSERT(ne03 == 1);
2046 GGML_ASSERT(ne13 == 1);
2047
2048 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2049 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2050 ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2051 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2052
2053 const uint32_t r2 = 1;
2054 const uint32_t r3 = 1;
2055
2056 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2057 // to the matrix-vector kernel
2058 // ne20 = n_used_experts
2059 // ne21 = n_rows (batch size)
2060 const int ne21_mm_id_min = 32;
2061
2062 if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
2063 // some Metal matrix data types require aligned pointers
2064 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2065 //switch (op->src[0]->type) {
2066 // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
2067 // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
2068 // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
2069 // default: break;
2070 //}
2071
2072 // extra buffers for intermediate id mapping
2073 ggml_metal_buffer_id bid_tpe = bid_dst;
2074 bid_tpe.offs += ggml_nbytes(op);
2075
2076 ggml_metal_buffer_id bid_ids = bid_tpe;
2077 bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);
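        // the extra workspace lives right after dst in the same Metal buffer: [dst | tpe | ids]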
2078
2079 {
2080 ggml_metal_kargs_mul_mm_id_map0 args = {
2081 ne02,
2082 ne10,
2083 ne11, // n_expert_used (bcast)
2084 nb11,
2085 nb12,
2086 ne21, // n_tokens
2087 ne20, // n_expert_used
2088 nb21,
2089 };
2090
2091 auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
2092
2093 const size_t smem = pipeline.smem;
2094
2095 GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2096
2097 GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2098
2099 ggml_metal_encoder_set_pipeline(enc, pipeline);
2100 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2101 ggml_metal_encoder_set_buffer (enc, bid_src2, 1);
2102 ggml_metal_encoder_set_buffer (enc, bid_tpe, 2);
2103 ggml_metal_encoder_set_buffer (enc, bid_ids, 3);
2104
2105 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2106
2107 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
2108 }
2109
2110 // this barrier is always needed because the next kernel has to wait for the id maps to be computed
2111 ggml_metal_op_concurrency_reset(ctx);
2112
2113 {
2114 auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
2115
2116 ggml_metal_kargs_mul_mm_id args = {
2117 /*.ne00 =*/ ne00,
2118 /*.ne02 =*/ ne02,
2119 /*.nb01 =*/ nb01,
2120 /*.nb02 =*/ nb02,
2121 /*.nb03 =*/ nb03,
2122 /*.ne11 =*/ ne11, // n_expert_used (bcast)
2123 /*.nb10 =*/ nb10,
2124 /*.nb11 =*/ nb11,
2125 /*.nb12 =*/ nb12,
2126 /*.nb13 =*/ nb13,
2127 /*.ne20 =*/ ne20, // n_expert_used
2128 /*.ne21 =*/ ne21, // n_tokens
2129 /*.ne0 =*/ ne0,
2130 /*.ne1 =*/ ne1,
2131 /*.r2 =*/ r2,
2132 /*.r3 =*/ r3,
2133 };
2134
2135 ggml_metal_encoder_set_pipeline(enc, pipeline);
2136 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2137 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2138 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2139 ggml_metal_encoder_set_buffer (enc, bid_tpe, 3);
2140 ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
2141 ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
2142
2143 const size_t smem = pipeline.smem;
2144
2145 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2146
2147 ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
2148 }
2149 } else {
2150 auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
2151
2152 const int nr0 = pipeline.nr0;
2153 const int nr1 = pipeline.nr1;
2154 const int nsg = pipeline.nsg;
2155
2156 const size_t smem = pipeline.smem;
2157
2158 ggml_metal_kargs_mul_mv_id args = {
2159 /*.nei0 =*/ ne20,
2160 /*.nei1 =*/ ne21,
2161 /*.nbi1 =*/ nb21,
2162 /*.ne00 =*/ ne00,
2163 /*.ne01 =*/ ne01,
2164 /*.ne02 =*/ ne02,
2165 /*.nb00 =*/ nb00,
2166 /*.nb01 =*/ nb01,
2167 /*.nb02 =*/ nb02,
2168 /*.ne10 =*/ ne10,
2169 /*.ne11 =*/ ne11,
2170 /*.ne12 =*/ ne12,
2171 /*.ne13 =*/ ne13,
2172 /*.nb10 =*/ nb10,
2173 /*.nb11 =*/ nb11,
2174 /*.nb12 =*/ nb12,
2175 /*.ne0 =*/ ne0,
2176 /*.ne1 =*/ ne1,
2177 /*.nb1 =*/ nb1,
2178 /*.nr0 =*/ nr0,
2179 };
2180
2181 if (ggml_is_quantized(op->src[0]->type)) {
2182 GGML_ASSERT(ne00 >= nsg*nr0);
2183 }
2184
2185 ggml_metal_encoder_set_pipeline(enc, pipeline);
2186 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
2187 ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
2188 ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
2189 ggml_metal_encoder_set_buffer(enc, bid_dst, 3);
2190 ggml_metal_encoder_set_buffer(enc, bid_src2, 4);
2191
2192 const int64_t _ne1 = 1;
2193 const int64_t ne123 = ne20*ne21;
2194
2195 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2196
2197 if (op->src[0]->type == GGML_TYPE_F32 ||
2198 op->src[0]->type == GGML_TYPE_F16 ||
2199 op->src[0]->type == GGML_TYPE_BF16 ||
2200 op->src[0]->type == GGML_TYPE_Q8_0) {
2201 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
2202 } else {
2203 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
2204 }
2205 }
2206
2207 return 1;
2208}
2209
2210int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
2211 ggml_tensor * op = ctx->node(idx);
2212
2213 ggml_metal_library_t lib = ctx->lib;
2214 ggml_metal_encoder_t enc = ctx->enc;
2215
2216 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2217 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2218 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2219 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2220 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2221 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2222 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2223
2224 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2225 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2226 GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
2227 GGML_ASSERT(op->type == GGML_TYPE_F32);
2228
2229 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2230
2231 ggml_metal_kargs_add_id args = {
2232 /*.ne0 =*/ ne0,
2233 /*.ne1 =*/ ne1,
2234 /*.nb01 =*/ nb01,
2235 /*.nb02 =*/ nb02,
2236 /*.nb11 =*/ nb11,
2237 /*.nb21 =*/ nb21,
2238 };
2239
2240 auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
2241
2242 ggml_metal_encoder_set_pipeline(enc, pipeline);
2243 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2244 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2245 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2246 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2247 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
2248
2249 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
2250
2251 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
2252
2253 return 1;
2254}
2255
2256bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
2257 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2258
2259 const int64_t ne00 = op->src[0]->ne[0]; // head size
2260 const int64_t ne01 = op->src[0]->ne[1]; // batch size
2261
2262 // use vec kernel if the batch size is small and if the head size is supported
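    // e.g. single-token decode (ne01 == 1) takes the vec path, while larger prompt batches (ne01 >= 20) use the simdgroup-matrix kernel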
2263 return (ne01 < 20) && (ne00 % 32 == 0);
2264}
2265
2266size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
2267 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2268
2269 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2270 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2271 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2272 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2273 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2274 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2275 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2276 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2277
2278 size_t res = 0;
2279
2280 const bool has_mask = op->src[3] != nullptr;
2281
2282    // note: the non-vec kernel requires more extra memory, so always reserve enough for it
2283 GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2284
2285 //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2286 if (false) {
2287 // note: always reserve the padding space to avoid graph reallocations
2288 //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2289 const bool has_kvpad = true;
2290
2291 if (has_kvpad) {
2292 res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2293 nb11*ne12*ne13 +
2294 nb21*ne22*ne23 +
2295 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2296 }
2297 } else {
2298 //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2299 const bool has_kvpad = true;
2300
2301 if (has_kvpad) {
2302 res += OP_FLASH_ATTN_EXT_NCPSG*(
2303 nb11*ne12*ne13 +
2304 nb21*ne22*ne23 +
2305 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2306 }
2307 }
2308
2309 return res;
2310}
2311
2312size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2313 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2314
2315 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2316 //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2317 //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2318 //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2319 //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2320 //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2321 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2322 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2323
2324 size_t res = 0;
2325
2326 const bool has_mask = op->src[3] != nullptr;
2327
2328 if (!has_mask) {
2329 return res;
2330 }
2331
2332 const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
2333
2334 // this optimization is not useful for the vector kernels
2335 // note: always reserve the blk buffer to avoid graph reallocations
2336 //if (is_vec) {
2337 // return res;
2338 //}
2339
2340 const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2341 const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2342
2343 const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2344 const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2345
2346 res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
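    // e.g. (assuming nqptg = ncpsg = 32): ne01 = 512 queries and ne30 = 4096 mask columns -> 128*16 int8 block flags per mask slice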
2347
2348 return res;
2349}
2350
2351size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
2352 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2353
2354 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2355 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2356 //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2357 //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2358 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2359 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2360 //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2361 //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2362
2363 size_t res = 0;
2364
2365 // note: always reserve the temp buffer to avoid graph reallocations
2366 //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2367 if (true) {
2368 const int64_t nwg = 32;
2369 const int64_t ne01_max = std::min(ne01, 32);
2370
2371 // temp buffer for writing the results from each workgroup
2372 // - ne20: the size of the Value head
2373 // - + 2: the S and M values for each intermediate result
2374 res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2375 }
2376
2377 return res;
2378}
2379
2380int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2381 ggml_tensor * op = ctx->node(idx);
2382
2383 ggml_metal_library_t lib = ctx->lib;
2384 ggml_metal_encoder_t enc = ctx->enc;
2385
2386 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
2387
2388 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2389 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2390 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2391 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2392 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2393 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2394 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2395 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2396 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2397 GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
2398
2399 GGML_ASSERT(ne00 % 4 == 0);
2400
2401 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2402 GGML_ASSERT(op->src[1]->type == op->src[2]->type);
2403
2404 //GGML_ASSERT(ggml_are_same_shape (src1, src2));
2405 GGML_ASSERT(ne11 == ne21);
2406 GGML_ASSERT(ne12 == ne22);
2407
2408 GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
2409 GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2410 "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
2411
2412 float scale;
2413 float max_bias;
2414 float logit_softcap;
2415
2416 memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
2417 memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
2418 memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
2419
2420 if (logit_softcap != 0.0f) {
2421 scale /= logit_softcap;
2422 }
2423
2424 const bool has_mask = op->src[3] != NULL;
2425 const bool has_sinks = op->src[4] != NULL;
2426 const bool has_bias = max_bias != 0.0f;
2427 const bool has_scap = logit_softcap != 0.0f;
2428
2429 const uint32_t n_head = op->src[0]->ne[2];
2430 const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2431
2432 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2433 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
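    // ALiBi slope bases; when max_bias > 0 the kernel derives the per-head slope from m0, m1 and n_head_log2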
2434
2435 GGML_ASSERT(ne01 < 65536);
2436
2437 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2438 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2439 ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2440 ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2441 ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2442
2443 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2444
2445 ggml_metal_buffer_id bid_pad = bid_dst;
2446 bid_pad.offs += ggml_nbytes(op);
2447
2448 ggml_metal_buffer_id bid_blk = bid_pad;
2449 bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2450
2451 ggml_metal_buffer_id bid_tmp = bid_blk;
2452 bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
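    // the extra workspace is appended after dst in the same Metal buffer: [dst | pad | blk | tmp]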
2453
2454 if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2455 // half8x8 kernel
2456 const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2457 const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2458
2459 GGML_ASSERT(nqptg <= 32);
2460 GGML_ASSERT(nqptg % 8 == 0);
2461 GGML_ASSERT(ncpsg % 32 == 0);
2462
2463 bool need_sync = false;
2464
2465 const bool has_kvpad = ne11 % ncpsg != 0;
2466
2467 if (has_kvpad) {
2468 assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2469
2470 ggml_metal_kargs_flash_attn_ext_pad args0 = {
2471 /*.ne11 =*/ne11,
2472 /*.ne_12_2 =*/ne12,
2473 /*.ne_12_3 =*/ne13,
2474 /*.nb11 =*/nb11,
2475 /*.nb12 =*/nb12,
2476 /*.nb13 =*/nb13,
2477 /*.nb21 =*/nb21,
2478 /*.nb22 =*/nb22,
2479 /*.nb23 =*/nb23,
2480 /*.ne31 =*/ne31,
2481 /*.ne32 =*/ne32,
2482 /*.ne33 =*/ne33,
2483 /*.nb31 =*/nb31,
2484 /*.nb32 =*/nb32,
2485 /*.nb33 =*/nb33,
2486 };
2487
2488 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2489
2490 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2491 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2492 ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2493 ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2494 ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2495 ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2496
2497 assert(ne12 == ne22);
2498 assert(ne13 == ne23);
2499
2500 ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2501
2502 need_sync = true;
2503 }
2504
2505 if (has_mask) {
2506 assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2507
2508 ggml_metal_kargs_flash_attn_ext_blk args0 = {
2509 /*.ne01 =*/ ne01,
2510 /*.ne30 =*/ ne30,
2511 /*.ne31 =*/ ne31,
2512 /*.ne32 =*/ ne32,
2513 /*.ne33 =*/ ne33,
2514 /*.nb31 =*/ nb31,
2515 /*.nb32 =*/ nb32,
2516 /*.nb33 =*/ nb33,
2517 };
2518
2519 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2520
2521 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2522 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2523 ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2524 ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2525
2526 const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2527 const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2528
2529 ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2530
2531 need_sync = true;
2532 }
2533
2534 if (need_sync) {
2535 ggml_metal_op_concurrency_reset(ctx);
2536 }
2537
2538 const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
2539
2540 // 2*(2*ncpsg)
2541 // ncpsg soft_max values + ncpsg mask values
2542 //
2543 // 16*32*(nsg)
2544 // the shared memory needed for the simdgroups to load the KV cache
2545        // each thread loads (dequantizes) 16 head elements; there are 32 threads in the SG
2546 //
2547#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
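        // e.g. (assuming nqptg = ncpsg = 32): ne00 = ne20 = 128 with f16/f32 KV (is_q = 0) gives
        // 32*(128 + 256 + 128)*2 = 32 KiB of threadgroup memory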
2548
2549 //int64_t nsgmax = 4;
2550 //
2551 //if (is_q) {
2552 // nsgmax = 2;
2553 // while (true) {
2554 // const size_t smem = FATTN_SMEM(nsgmax);
2555 // if (smem > props_dev->max_theadgroup_memory_size) {
2556 // break;
2557 // }
2558 // nsgmax *= 2;
2559 // }
2560 // nsgmax /= 2;
2561 //}
2562
2563 // simdgroups per threadgroup (a.k.a. warps)
2564 //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2565 int32_t nsg = ne00 >= 512 ? 8 : 4;
2566
2567 const size_t smem = FATTN_SMEM(nsg);
2568
2569 ggml_metal_kargs_flash_attn_ext args = {
2570 /*.ne01 =*/ ne01,
2571 /*.ne02 =*/ ne02,
2572 /*.ne03 =*/ ne03,
2573 /*.nb01 =*/ nb01,
2574 /*.nb02 =*/ nb02,
2575 /*.nb03 =*/ nb03,
2576 /*.ne11 =*/ ne11,
2577 /*.ne_12_2 =*/ ne12,
2578 /*.ne_12_3 =*/ ne13,
2579 /*.ns10 =*/ int32_t(nb11/nb10),
2580 /*.nb11 =*/ nb11,
2581 /*.nb12 =*/ nb12,
2582 /*.nb13 =*/ nb13,
2583 /*.ns20 =*/ int32_t(nb21/nb20),
2584 /*.nb21 =*/ nb21,
2585 /*.nb22 =*/ nb22,
2586 /*.nb23 =*/ nb23,
2587 /*.ne31 =*/ ne31,
2588 /*.ne32 =*/ ne32,
2589 /*.ne33 =*/ ne33,
2590 /*.nb31 =*/ nb31,
2591 /*.nb32 =*/ nb32,
2592 /*.nb33 =*/ nb33,
2593 /*.ne1 =*/ ne1,
2594 /*.ne2 =*/ ne2,
2595 /*.ne3 =*/ ne3,
2596 /*.scale =*/ scale,
2597 /*.max_bias =*/ max_bias,
2598 /*.m0 =*/ m0,
2599 /*.m1 =*/ m1,
2600 /*.n_head_log2 =*/ n_head_log2,
2601 /*.logit_softcap =*/ logit_softcap,
2602 };
2603
2604 auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2605
2606 ggml_metal_encoder_set_pipeline(enc, pipeline);
2607 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2608 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2609 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2610 ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2611 ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2612 ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2613 ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2614 ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2615 ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2616
2617 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2618
2619 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
2620#undef FATTN_SMEM
2621 } else {
2622 // half4x4 kernel
2623 const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2624 const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2625 const int nhptg = 1; // heads per threadgroup
2626
2627 GGML_ASSERT(nqptg <= 32);
2628 GGML_ASSERT(nqptg % 1 == 0);
2629 GGML_ASSERT(ncpsg % 32 == 0);
2630
2631 bool need_sync = false;
2632
2633 const bool has_kvpad = ne11 % ncpsg != 0;
2634
2635 if (has_kvpad) {
2636 assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2637
2638 ggml_metal_kargs_flash_attn_ext_pad args0 = {
2639 /*.ne11 =*/ne11,
2640 /*.ne_12_2 =*/ne12,
2641 /*.ne_12_3 =*/ne13,
2642 /*.nb11 =*/nb11,
2643 /*.nb12 =*/nb12,
2644 /*.nb13 =*/nb13,
2645 /*.nb21 =*/nb21,
2646 /*.nb22 =*/nb22,
2647 /*.nb23 =*/nb23,
2648 /*.ne31 =*/ne31,
2649 /*.ne32 =*/ne32,
2650 /*.ne33 =*/ne33,
2651 /*.nb31 =*/nb31,
2652 /*.nb32 =*/nb32,
2653 /*.nb33 =*/nb33,
2654 };
2655
2656 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2657
2658 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2659 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2660 ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2661 ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2662 ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2663 ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2664
2665 assert(ne12 == ne22);
2666 assert(ne13 == ne23);
2667
2668 ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2669
2670 need_sync = true;
2671 }
2672
2673 if (need_sync) {
2674 ggml_metal_op_concurrency_reset(ctx);
2675 }
2676
2677        // note: for simplicity, assume that K is larger than or equal to V
2678 GGML_ASSERT(ne10 >= ne20);
2679
2680 // ne00 + 2*ncpsg*(nsg)
2681 // for each query, we load it as f16 in shared memory (ne00)
2682 // and store the soft_max values and the mask
2683 //
2684 // ne20*(nsg)
2685 // each simdgroup has a full f32 head vector in shared mem to accumulate results
2686 //
2687#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
2688
2689 int64_t nsg = 1;
2690
2691 // workgroups
2692        // each workgroup handles nsg*ncpsg cache values
2693 int32_t nwg = 1;
2694 if (false) {
2695            // for small KV caches, we could launch a single workgroup and write the results directly to dst;
2696            // however, this does not lead to a significant improvement, so it is disabled for now
2697 nwg = 1;
2698 nsg = 4;
2699 } else {
2700 nwg = 32;
2701 nsg = 1;
2702 while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2703 nsg *= 2;
2704 }
2705 }
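        // e.g. (assuming ncpsg = 32): ne11 = 8192 -> nsg doubles while 2*32*nsg*32 < 8192, ending at nsg = 4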
2706
2707 ggml_metal_kargs_flash_attn_ext_vec args = {
2708 /*.ne01 =*/ ne01,
2709 /*.ne02 =*/ ne02,
2710 /*.ne03 =*/ ne03,
2711 /*.nb01 =*/ nb01,
2712 /*.nb02 =*/ nb02,
2713 /*.nb03 =*/ nb03,
2714 /*.ne11 =*/ ne11,
2715 /*.ne_12_2 =*/ ne12,
2716 /*.ne_12_3 =*/ ne13,
2717 /*.ns10 =*/ int32_t(nb11/nb10),
2718 /*.nb11 =*/ nb11,
2719 /*.nb12 =*/ nb12,
2720 /*.nb13 =*/ nb13,
2721 /*.ns20 =*/ int32_t(nb21/nb20),
2722 /*.nb21 =*/ nb21,
2723 /*.nb22 =*/ nb22,
2724 /*.nb23 =*/ nb23,
2725 /*.ne31 =*/ ne31,
2726 /*.ne32 =*/ ne32,
2727 /*.ne33 =*/ ne33,
2728 /*.nb31 =*/ nb31,
2729 /*.nb32 =*/ nb32,
2730 /*.nb33 =*/ nb33,
2731 /*.ne1 =*/ ne1,
2732 /*.ne2 =*/ ne2,
2733 /*.ne3 =*/ ne3,
2734 /*.scale =*/ scale,
2735 /*.max_bias =*/ max_bias,
2736 /*.m0 =*/ m0,
2737 /*.m1 =*/ m1,
2738 /*.n_head_log2 =*/ n_head_log2,
2739 /*.logit_softcap =*/ logit_softcap,
2740 };
2741
2742 auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2743
2744 GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2745
2746 ggml_metal_encoder_set_pipeline(enc, pipeline);
2747 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2748 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2749 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2750 ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2751 ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2752 ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2753
2754 const size_t smem = FATTN_SMEM(nsg);
2755
2756 //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2757 GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2758
2759 if (nwg == 1) {
2760 assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2761
2762 // using 1 workgroup -> write the result directly into dst
2763 ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2764 ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2765
2766 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2767
2768 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2769 } else {
2770 // sanity checks
2771 assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2772
2773 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2774 GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2775
2776 // write the results from each workgroup into a temp buffer
2777 ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2778 ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2779
2780 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2781 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2782
2783 // sync the 2 kernels
2784 ggml_metal_op_concurrency_reset(ctx);
2785
2786 // reduce the results from the workgroups
2787 {
2788 const int32_t nrows = ne1*ne2*ne3;
2789
2790 ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
2791 nrows,
2792 };
2793
2794 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2795
2796 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2797 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2798 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
2799 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2800
2801 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
2802 }
2803 }
2804#undef FATTN_SMEM
2805 }
2806
2807 return 1;
2808}
2809
2810int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2811 ggml_tensor * op = ctx->node(idx);
2812
2813 ggml_metal_library_t lib = ctx->lib;
2814 ggml_metal_encoder_t enc = ctx->enc;
2815
2816 const bool use_fusion = ctx->use_fusion;
2817
2818 const int debug_fusion = ctx->debug_fusion;
2819
2820 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2821 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2822 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2823 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2824 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2825 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2826
2827 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2828 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2829
2830 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2831 GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2832
2833 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2834 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2835 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2836
2837 ggml_metal_kargs_bin args = {
2838 /*.ne00 =*/ ne00,
2839 /*.ne01 =*/ ne01,
2840 /*.ne02 =*/ ne02,
2841 /*.ne03 =*/ ne03,
2842 /*.nb00 =*/ nb00,
2843 /*.nb01 =*/ nb01,
2844 /*.nb02 =*/ nb02,
2845 /*.nb03 =*/ nb03,
2846 /*.ne10 =*/ ne10,
2847 /*.ne11 =*/ ne11,
2848 /*.ne12 =*/ ne12,
2849 /*.ne13 =*/ ne13,
2850 /*.nb10 =*/ nb10,
2851 /*.nb11 =*/ nb11,
2852 /*.nb12 =*/ nb12,
2853 /*.nb13 =*/ nb13,
2854 /*.ne0 =*/ ne0,
2855 /*.ne1 =*/ ne1,
2856 /*.ne2 =*/ ne2,
2857 /*.ne3 =*/ ne3,
2858 /*.nb0 =*/ nb0,
2859 /*.nb1 =*/ nb1,
2860 /*.nb2 =*/ nb2,
2861 /*.nb3 =*/ nb3,
2862 /*.offs =*/ 0,
2863 /*.o1 =*/ { bid_src1.offs },
2864 };
2865
2866 ggml_op fops[8];
2867
2868 int n_fuse = 1;
2869
2870 // c[0] = add(a, b[0])
2871 // c[1] = add(c[0], b[1])
2872 // c[2] = add(c[1], b[2])
2873 // ...
2874 if (use_fusion) {
2875 fops[0] = GGML_OP_ADD;
2876 fops[1] = GGML_OP_ADD;
2877 fops[2] = GGML_OP_ADD;
2878 fops[3] = GGML_OP_ADD;
2879 fops[4] = GGML_OP_ADD;
2880 fops[5] = GGML_OP_ADD;
2881 fops[6] = GGML_OP_ADD;
2882 fops[7] = GGML_OP_ADD;
2883
2884        // note: in Metal we sometimes encode the graph in parallel, so we have to avoid fusing ops
2885        //       across splits. idx_end indicates the last node in the current split
2886 for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
2887 if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
2888 break;
2889 }
2890
2891 ggml_tensor * f0 = ctx->node(idx + n_fuse);
2892 ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
2893
2894 if (f0 != f1->src[0]) {
2895 break;
2896 }
2897
2898 // b[0] === b[1] === ...
2899 if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
2900 break;
2901 }
2902
2903 // only fuse ops if src1 is in the same Metal buffer
2904 ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
2905 if (bid_fuse.metal != bid_src1.metal) {
2906 break;
2907 }
2908
2909 //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
2910
2911 args.o1[n_fuse + 1] = bid_fuse.offs;
2912 }
2913
2914 ++n_fuse;
2915
2916 if (debug_fusion > 1 && n_fuse > 1) {
2917 GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2918 }
2919 }
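    // at this point up to n_fuse consecutive ADD nodes are handled by a single kernel launch:
    // the per-node src1 offsets are passed via args.o1 and the output is written to the last fused node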
2920
2921 // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2922 bid_src1.offs = 0;
2923
2924 struct ggml_metal_pipeline_with_params pipeline;
2925
2926 pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
2927
2928 if (n_fuse > 1) {
2929 bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
2930
2931 for (int i = 1; i < n_fuse; ++i) {
2932 if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
2933 ggml_metal_op_concurrency_reset(ctx);
2934
2935 break;
2936 }
2937 }
2938 }
2939
2940 if (pipeline.c4) {
2941 args.ne00 = ne00/4;
2942 args.ne10 = ne10/4;
2943 args.ne0 = ne0/4;
2944 }
2945
2946 ggml_metal_encoder_set_pipeline(enc, pipeline);
2947 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2948 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2949 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2950 ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2951
2952 if (pipeline.cnt) {
2953 const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
2954
2955 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
2956 } else {
2957 const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2958
2959 int nth = 1;
2960
2961 while (2*nth < args.ne0 && nth < nth_max) {
2962 nth *= 2;
2963 }
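        // e.g. args.ne0 = 4096 -> nth doubles up to nth_max (256 here, assuming the pipeline allows it);
        // one threadgroup then processes one row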
2964
2965 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2966 }
2967
2968 return n_fuse;
2969}
2970
2971int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2972 ggml_tensor * op = ctx->node(idx);
2973
2974 ggml_metal_library_t lib = ctx->lib;
2975 ggml_metal_encoder_t enc = ctx->enc;
2976
2977 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2978 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2979 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2980 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2981
2982 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2983
2984 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2985 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2986
2987 float eps;
2988 memcpy(&eps, op->op_params, sizeof(float));
2989
2990 ggml_metal_kargs_l2_norm args = {
2991 /*.ne00 =*/ ne00,
2992 /*.ne01 =*/ ne01,
2993 /*.ne02 =*/ ne02,
2994 /*.ne03 =*/ ne03,
2995 /*.nb00 =*/ nb00,
2996 /*.nb01 =*/ nb01,
2997 /*.nb02 =*/ nb02,
2998 /*.nb03 =*/ nb03,
2999 /*.ne0 =*/ ne0,
3000 /*.ne1 =*/ ne1,
3001 /*.ne2 =*/ ne2,
3002 /*.ne3 =*/ ne3,
3003 /*.nb0 =*/ nb0,
3004 /*.nb1 =*/ nb1,
3005 /*.nb2 =*/ nb2,
3006 /*.nb3 =*/ nb3,
3007 /*.eps =*/ eps,
3008 };
3009
3010 auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
3011
3012 if (pipeline.c4) {
3013 args.ne00 = ne00/4;
3014 args.ne0 = ne0/4;
3015 }
3016
3017 int nth = 32; // SIMD width
3018
3019 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3020 nth *= 2;
3021 }
3022
3023 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3024
3025 const size_t smem = pipeline.smem;
3026
3027 ggml_metal_encoder_set_pipeline(enc, pipeline);
3028 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3029 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3030 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3031
3032 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3033
3034 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3035
3036 return 1;
3037}
3038
3039int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
3040 ggml_tensor * op = ctx->node(idx);
3041
3042 ggml_metal_library_t lib = ctx->lib;
3043 ggml_metal_encoder_t enc = ctx->enc;
3044
3045 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3046 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3047 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3048 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3049
3050 const int32_t ngrp = ((const int32_t *) op->op_params)[0];
3051
3052 float eps;
3053 memcpy(&eps, op->op_params + 1, sizeof(float));
3054
3055 ggml_metal_kargs_group_norm args = {
3056 /*.ne00 =*/ ne00,
3057 /*.ne01 =*/ ne01,
3058 /*.ne02 =*/ ne02,
3059 /*.nb00 =*/ nb00,
3060 /*.nb01 =*/ nb01,
3061 /*.nb02 =*/ nb02,
3062 /*.ngrp =*/ ngrp,
3063 /*.eps =*/ eps,
3064 };
3065
3066 auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
3067
3068 int nth = 32; // SIMD width
3069 //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3070 // nth *= 2;
3071 //}
3072
3073 //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3074 //nth = std::min(nth, ne00/4);
3075
3076 const size_t smem = pipeline.smem;
3077
3078 ggml_metal_encoder_set_pipeline(enc, pipeline);
3079 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3080 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3081 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3082
3083 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3084
3085 ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
3086
3087 return 1;
3088}
3089
3090int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
3091 ggml_tensor * op = ctx->node(idx);
3092
3093 ggml_metal_library_t lib = ctx->lib;
3094 ggml_metal_encoder_t enc = ctx->enc;
3095
3096 const bool use_fusion = ctx->use_fusion;
3097
3098 const int debug_fusion = ctx->debug_fusion;
3099
3100 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3101 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3102 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3103 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3104
3105 float eps;
3106 memcpy(&eps, op->op_params, sizeof(float));
3107
3108 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3109 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3110
3111 ggml_metal_kargs_norm args = {
3112 /*.ne00 =*/ ne00,
3113 /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
3114 /*.nb1 =*/ nb1,
3115 /*.nb2 =*/ nb2,
3116 /*.nb3 =*/ nb3,
3117 /*.eps =*/ eps,
3118 /*.nef1 =*/ { ne01 },
3119 /*.nef2 =*/ { ne02 },
3120 /*.nef3 =*/ { ne03 },
3121 /*.nbf1 =*/ { nb01 },
3122 /*.nbf2 =*/ { nb02 },
3123 /*.nbf3 =*/ { nb03 },
3124 };
3125
3126 ggml_op fops[8];
3127
3128 int n_fuse = 1;
3129
3130 ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
3131
3132 // d[0] = norm(a)
3133 // d[1] = mul(d[0], b)
3134 // d[2] = add(d[1], c)
3135 if (use_fusion) {
3136 fops[0] = op->op;
3137 fops[1] = GGML_OP_MUL;
3138 fops[2] = GGML_OP_ADD;
3139
3140 for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
3141 if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
3142 break;
3143 }
3144
3145 ggml_tensor * f0 = ctx->node(idx + n_fuse);
3146 ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
3147
3148 if (f0 != f1->src[0]) {
3149 break;
3150 }
3151
3152 if (f1->src[1]->ne[0] != op->ne[0]) {
3153 break;
3154 }
3155
3156 if (!ggml_is_contiguous_rows(f1->src[1])) {
3157 break;
3158 }
3159
3160 if (f1->type != GGML_TYPE_F32) {
3161 break;
3162 }
3163
3164 //ctx->fuse_cnt[f1->op]++;
3165
3166 bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
3167
3168 args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
3169 args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
3170 args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
3171
3172 args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
3173 args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
3174 args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
3175 }
3176
3177 ++n_fuse;
3178
3179 if (debug_fusion > 1 && n_fuse > 1) {
3180 if (n_fuse == 2) {
3181 GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
3182 }
3183 if (n_fuse == 3) {
3184 GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
3185 }
3186 }
3187 }
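    // n_fuse == 2 fuses NORM/RMS_NORM + MUL (e.g. the common rms_norm(x)*w pattern); n_fuse == 3 additionally
    // fuses a trailing ADD; the fused operands reach the kernel via bid_fuse[0..1]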
3188
3189 if (n_fuse > 1) {
3190 bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
3191
3192 for (int i = 1; i < n_fuse; ++i) {
3193 if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
3194 ggml_metal_op_concurrency_reset(ctx);
3195
3196 break;
3197 }
3198 }
3199 }
3200
3201 auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3202
3203 int nth = 32; // SIMD width
3204
3205 while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3206 nth *= 2;
3207 }
3208
3209 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3210 nth = std::min(nth, args.ne00_t);
3211
3212 const size_t smem = pipeline.smem;
3213
3214 ggml_metal_encoder_set_pipeline(enc, pipeline);
3215 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3216 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3217 ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
3218 ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
3219 ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
3220
3221 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3222
3223 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3224
3225 return n_fuse;
3226}
3227
3228int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
3229 ggml_tensor * op = ctx->node(idx);
3230
3231 ggml_metal_library_t lib = ctx->lib;
3232 ggml_metal_encoder_t enc = ctx->enc;
3233
3234 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3235 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3236 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3237 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3238 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3239 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3240
3241    // make sure we have one or more position ids (ne10) per token (ne02)
3242 GGML_ASSERT(ne10 % ne02 == 0);
3243 GGML_ASSERT(ne10 >= ne02);
3244
3245 const int nth = std::min(1024, ne00);
3246
3247 const int n_past = ((const int32_t *) op->op_params)[0];
3248 const int n_dims = ((const int32_t *) op->op_params)[1];
3249 //const int mode = ((const int32_t *) op->op_params)[2];
3250    // skip op_params[3] (n_ctx), used by GLM RoPE and not implemented in Metal
3251 const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
3252
3253 float freq_base;
3254 float freq_scale;
3255 float ext_factor;
3256 float attn_factor;
3257 float beta_fast;
3258 float beta_slow;
3259
3260 memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float));
3261 memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float));
3262 memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float));
3263 memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float));
3264 memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float));
3265 memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float));
3266
3267 // mrope
3268 const int sect_0 = ((const int32_t *) op->op_params)[11];
3269 const int sect_1 = ((const int32_t *) op->op_params)[12];
3270 const int sect_2 = ((const int32_t *) op->op_params)[13];
3271 const int sect_3 = ((const int32_t *) op->op_params)[14];
3272
3273 ggml_metal_kargs_rope args = {
3274 /*.ne00 =*/ ne00,
3275 /*.ne01 =*/ ne01,
3276 /*.ne02 =*/ ne02,
3277 /*.ne03 =*/ ne03,
3278 /*.nb00 =*/ nb00,
3279 /*.nb01 =*/ nb01,
3280 /*.nb02 =*/ nb02,
3281 /*.nb03 =*/ nb03,
3282 /*.ne0 =*/ ne0,
3283 /*.ne1 =*/ ne1,
3284 /*.ne2 =*/ ne2,
3285 /*.ne3 =*/ ne3,
3286 /*.nb0 =*/ nb0,
3287 /*.nb1 =*/ nb1,
3288 /*.nb2 =*/ nb2,
3289 /*.nb3 =*/ nb3,
3290 /*.n_past =*/ n_past,
3291 /*.n_dims =*/ n_dims,
3292 /*.n_ctx_orig =*/ n_ctx_orig,
3293 /*.freq_base =*/ freq_base,
3294 /*.freq_scale =*/ freq_scale,
3295 /*.ext_factor =*/ ext_factor,
3296 /*.attn_factor =*/ attn_factor,
3297 /*.beta_fast =*/ beta_fast,
3298 /*.beta_slow =*/ beta_slow,
3299 /* sect_0 =*/ sect_0,
3300 /* sect_1 =*/ sect_1,
3301 /* sect_2 =*/ sect_2,
3302 /* sect_3 =*/ sect_3,
3303 /* src2 =*/ op->src[2] != nullptr,
3304 };
3305
3306 auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
3307
3308 ggml_metal_encoder_set_pipeline(enc, pipeline);
3309 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3310 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3311 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3312 if (op->src[2]) {
3313 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
3314 } else {
3315 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3);
3316 }
3317 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
3318
3319 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3320
3321 return 1;
3322}
3323
3324int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
3325 ggml_tensor * op = ctx->node(idx);
3326
3327 ggml_metal_library_t lib = ctx->lib;
3328 ggml_metal_encoder_t enc = ctx->enc;
3329
3330 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3331 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3332 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3333 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3334
3335 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3336 const int32_t s1 = ((const int32_t *)(op->op_params))[1];
3337 const int32_t p0 = ((const int32_t *)(op->op_params))[2];
3338 const int32_t p1 = ((const int32_t *)(op->op_params))[3];
3339 const int32_t d0 = ((const int32_t *)(op->op_params))[4];
3340 const int32_t d1 = ((const int32_t *)(op->op_params))[5];
3341
3342 const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
3343
3344 const int32_t N = op->src[1]->ne[is_2D ? 3 : 2];
3345 const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
3346 const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
3347 const int32_t IW = op->src[1]->ne[0];
3348
3349 const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
3350 const int32_t KW = op->src[0]->ne[0];
3351
3352 const int32_t OH = is_2D ? op->ne[2] : 1;
3353 const int32_t OW = op->ne[1];
3354
3355 const int32_t CHW = IC * KH * KW;
3356
3357 const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
3358 const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
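    // the divisions by 4 convert byte strides into element offsets (src1 is assumed to be F32, 4 bytes per element)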
3359
3360 ggml_metal_kargs_im2col args = {
3361 /*.ofs0 =*/ ofs0,
3362 /*.ofs1 =*/ ofs1,
3363 /*.IW =*/ IW,
3364 /*.IH =*/ IH,
3365 /*.CHW =*/ CHW,
3366 /*.s0 =*/ s0,
3367 /*.s1 =*/ s1,
3368 /*.p0 =*/ p0,
3369 /*.p1 =*/ p1,
3370 /*.d0 =*/ d0,
3371 /*.d1 =*/ d1,
3372 /*.N =*/ N,
3373 /*.KH =*/ KH,
3374 /*.KW =*/ KW,
3375 /*.KHW =*/ KH * KW,
3376 };
3377
3378 auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3379
3380 GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3381
3382 const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3383
3384 ggml_metal_encoder_set_pipeline(enc, pipeline);
3385 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3386 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3387 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3388
3389 ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3390
3391 return 1;
3392}
3393
3394int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3395 ggml_tensor * op = ctx->node(idx);
3396
3397 ggml_metal_library_t lib = ctx->lib;
3398 ggml_metal_encoder_t enc = ctx->enc;
3399
3400 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3401 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3402 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3403 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3404 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3405 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3406
3407 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3408 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
3409 GGML_ASSERT(op->type == GGML_TYPE_F32);
3410 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
3411
3412 const int32_t s0 = ((const int32_t *) op->op_params)[0];
3413 const int32_t s1 = ((const int32_t *) op->op_params)[1];
3414 const int32_t p0 = ((const int32_t *) op->op_params)[2];
3415 const int32_t p1 = ((const int32_t *) op->op_params)[3];
3416 const int32_t d0 = ((const int32_t *) op->op_params)[4];
3417 const int32_t d1 = ((const int32_t *) op->op_params)[5];
3418
3419 ggml_metal_kargs_conv_2d args = {
3420 /*.nb00 =*/ nb00,
3421 /*.nb01 =*/ nb01,
3422 /*.nb02 =*/ nb02,
3423 /*.nb03 =*/ nb03,
3424 /*.nb10 =*/ nb10,
3425 /*.nb11 =*/ nb11,
3426 /*.nb12 =*/ nb12,
3427 /*.nb13 =*/ nb13,
3428 /*.nb0 =*/ nb0,
3429 /*.nb1 =*/ nb1,
3430 /*.nb2 =*/ nb2,
3431 /*.nb3 =*/ nb3,
3432 /*.IW =*/ ne10,
3433 /*.IH =*/ ne11,
3434 /*.KW =*/ ne00,
3435 /*.KH =*/ ne01,
3436 /*.IC =*/ ne02,
3437 /*.OC =*/ ne03,
3438 /*.OW =*/ ne0,
3439 /*.OH =*/ ne1,
3440 /*.N =*/ ne3,
3441 /*.s0 =*/ s0,
3442 /*.s1 =*/ s1,
3443 /*.p0 =*/ p0,
3444 /*.p1 =*/ p1,
3445 /*.d0 =*/ d0,
3446 /*.d1 =*/ d1,
3447 };
3448
3449 auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
3450
3451 int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3452 nth = std::min(nth, 256);
3453 nth = std::max(nth, 1);
3454
3455 const uint64_t n_out = ggml_nelements(op);
3456
3457 uint64_t tg = (n_out + nth - 1)/nth;
3458 tg = std::max<uint64_t>(tg, 1);
3459 tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
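    // e.g. n_out = 1000000 output elements with nth = 256 -> tg = ceil(1000000/256) = 3907 threadgroups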
3460
3461 ggml_metal_encoder_set_pipeline(enc, pipeline);
3462 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3463 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3464 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3465 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3466
3467 ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3468
3469 return 1;
3470}
3471
3472int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3473 ggml_tensor * op = ctx->node(idx);
3474
3475 ggml_metal_library_t lib = ctx->lib;
3476 ggml_metal_encoder_t enc = ctx->enc;
3477
3478 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3479 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3480 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3481 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3482 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3483 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3484
3485 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3486
3487 const int32_t IC = op->src[1]->ne[1];
3488 const int32_t IL = op->src[1]->ne[0];
3489
3490 const int32_t K = op->src[0]->ne[0];
3491
3492 const int32_t OL = op->ne[0];
3493 const int32_t OC = op->ne[1];
3494
3495 ggml_metal_kargs_conv_transpose_1d args = {
3496 /*.IC =*/ IC,
3497 /*.IL =*/ IL,
3498 /*.K =*/ K,
3499 /*.s0 =*/ s0,
3500 /*.nb0 =*/ nb0,
3501 /*.nb1 =*/ nb1,
3502 };
3503
3504 auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3505
3506 ggml_metal_encoder_set_pipeline(enc, pipeline);
3507 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3508 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3509 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3510 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3511
3512 ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
3513
3514 return 1;
3515}
3516
3517int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
3518 ggml_tensor * op = ctx->node(idx);
3519
3520 ggml_metal_library_t lib = ctx->lib;
3521 ggml_metal_encoder_t enc = ctx->enc;
3522
3523 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3524 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3525 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3526 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3527 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3528 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3529
3530 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3531
3532 const int32_t IC = op->src[1]->ne[2];
3533 const int32_t IH = op->src[1]->ne[1];
3534 const int32_t IW = op->src[1]->ne[0];
3535
3536 const int32_t KH = op->src[0]->ne[1];
3537 const int32_t KW = op->src[0]->ne[0];
3538
3539 const int32_t OW = op->ne[0];
3540 const int32_t OH = op->ne[1];
3541 const int32_t OC = op->ne[2];
3542
3543 ggml_metal_kargs_conv_transpose_2d args = {
3544 /*.IC =*/ IC,
3545 /*.IH =*/ IH,
3546 /*.IW =*/ IW,
3547 /*.KH =*/ KH,
3548 /*.KW =*/ KW,
3549 /*.OC =*/ OC,
3550 /*.s0 =*/ s0,
3551 /*.nb0 =*/ nb0,
3552 /*.nb1 =*/ nb1,
3553 /*.nb2 =*/ nb2,
3554 };
3555
3556 auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3557
3558 ggml_metal_encoder_set_pipeline(enc, pipeline);
3559 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3560 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3561 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3562 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3563
    // Metal requires the threadgroup memory size to be a multiple of 16 bytes
3565 const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
3566 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3567
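    // one threadgroup per output element (OW x OH x OC grid) with KW x KH threads each;
    // the threadgroup buffer sized above holds one float per kernel tap, presumably for per-tap partial sums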
3568 ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3569
3570 return 1;
3571}
3572
3573int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3574 ggml_tensor * op = ctx->node(idx);
3575
3576 ggml_metal_library_t lib = ctx->lib;
3577 ggml_metal_encoder_t enc = ctx->enc;
3578
3579 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3580 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3581 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3582 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3583
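    // per-dimension scale factors: output extent divided by input extent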
3584 const float sf0 = (float)ne0/op->src[0]->ne[0];
3585 const float sf1 = (float)ne1/op->src[0]->ne[1];
3586 const float sf2 = (float)ne2/op->src[0]->ne[2];
3587 const float sf3 = (float)ne3/op->src[0]->ne[3];
3588
3589 ggml_metal_kargs_upscale args = {
3590 /*.ne00 =*/ ne00,
3591 /*.ne01 =*/ ne01,
3592 /*.ne02 =*/ ne02,
3593 /*.ne03 =*/ ne03,
3594 /*.nb00 =*/ nb00,
3595 /*.nb01 =*/ nb01,
3596 /*.nb02 =*/ nb02,
3597 /*.nb03 =*/ nb03,
3598 /*.ne0 =*/ ne0,
3599 /*.ne1 =*/ ne1,
3600 /*.ne2 =*/ ne2,
3601 /*.ne3 =*/ ne3,
3602 /*.nb0 =*/ nb0,
3603 /*.nb1 =*/ nb1,
3604 /*.nb2 =*/ nb2,
3605 /*.nb3 =*/ nb3,
3606 /*.sf0 =*/ sf0,
3607 /*.sf1 =*/ sf1,
3608 /*.sf2 =*/ sf2,
3609 /*.sf3 =*/ sf3
3610 };
3611
3612 auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3613
3614 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3615
3616 ggml_metal_encoder_set_pipeline(enc, pipeline);
3617 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3618 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3619 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3620
3621 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3622
3623 return 1;
3624}
3625
3626int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
3627 ggml_tensor * op = ctx->node(idx);
3628
3629 ggml_metal_library_t lib = ctx->lib;
3630 ggml_metal_encoder_t enc = ctx->enc;
3631
3632 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3633 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3634 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3635 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3636
3637 ggml_metal_kargs_pad args = {
3638 /*.ne00 =*/ ne00,
3639 /*.ne01 =*/ ne01,
3640 /*.ne02 =*/ ne02,
3641 /*.ne03 =*/ ne03,
3642 /*.nb00 =*/ nb00,
3643 /*.nb01 =*/ nb01,
3644 /*.nb02 =*/ nb02,
3645 /*.nb03 =*/ nb03,
3646 /*.ne0 =*/ ne0,
3647 /*.ne1 =*/ ne1,
3648 /*.ne2 =*/ ne2,
3649 /*.ne3 =*/ ne3,
3650 /*.nb0 =*/ nb0,
3651 /*.nb1 =*/ nb1,
3652 /*.nb2 =*/ nb2,
3653 /*.nb3 =*/ nb3
3654 };
3655
3656 auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3657
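    // one threadgroup per output row (ne1 x ne2 x ne3 grid), with up to 1024 threads spanning the row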
3658 const int nth = std::min(1024, ne0);
3659
3660 ggml_metal_encoder_set_pipeline(enc, pipeline);
3661 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3662 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3663 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3664
3665 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3666
3667 return 1;
3668}
3669
3670int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
3671 ggml_tensor * op = ctx->node(idx);
3672
3673 ggml_metal_library_t lib = ctx->lib;
3674 ggml_metal_encoder_t enc = ctx->enc;
3675
3676 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3677 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3678 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3679 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3680
3681 ggml_metal_kargs_pad_reflect_1d args = {
3682 /*.ne00 =*/ ne00,
3683 /*.ne01 =*/ ne01,
3684 /*.ne02 =*/ ne02,
3685 /*.ne03 =*/ ne03,
3686 /*.nb00 =*/ nb00,
3687 /*.nb01 =*/ nb01,
3688 /*.nb02 =*/ nb02,
3689 /*.nb03 =*/ nb03,
3690 /*.ne0 =*/ ne0,
3691 /*.ne1 =*/ ne1,
3692 /*.ne2 =*/ ne2,
3693 /*.ne3 =*/ ne3,
3694 /*.nb0 =*/ nb0,
3695 /*.nb1 =*/ nb1,
3696 /*.nb2 =*/ nb2,
3697 /*.nb3 =*/ nb3,
3698 /*.p0 =*/ ((const int32_t *)(op->op_params))[0],
3699 /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
3700 };
3701
3702 auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3703
3704 const int nth = std::min(1024, ne0);
3705
3706 ggml_metal_encoder_set_pipeline(enc, pipeline);
3707 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3708 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3709 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3710
3711 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3712
3713 return 1;
3714}
3715
3716int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
3717 ggml_tensor * op = ctx->node(idx);
3718
3719 ggml_metal_library_t lib = ctx->lib;
3720 ggml_metal_encoder_t enc = ctx->enc;
3721
3722 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3723 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3724
3725 float start;
3726 float step;
3727
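    // the arange parameters are packed as floats inside op_params: start at float index 0, step at float index 2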
3728 memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
3729 memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float));
3730
3731 ggml_metal_kargs_arange args = {
3732 /*.ne0 =*/ ne0,
3733 /*.start =*/ start,
3734 /*.step =*/ step
3735 };
3736
3737 const int nth = std::min(1024, ne0);
3738
3739 auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
3740
3741 ggml_metal_encoder_set_pipeline(enc, pipeline);
3742 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3743 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
3744
3745 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
3746
3747 return 1;
3748}
3749
3750int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3751 ggml_tensor * op = ctx->node(idx);
3752
3753 ggml_metal_library_t lib = ctx->lib;
3754 ggml_metal_encoder_t enc = ctx->enc;
3755
3756 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3757 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3758 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3759 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3760
3761 const int dim = op->op_params[0];
3762 const int max_period = op->op_params[1];
3763
3764 ggml_metal_kargs_timestep_embedding args = {
3765 /*.nb1 =*/ nb1,
3766 /*.dim =*/ dim,
3767 /*.max_period =*/ max_period,
3768 };
3769
3770 auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3771
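    // the embedding has dim/2 frequency components, so use one thread per component
    // (presumably each thread writes a cos/sin pair), capped at 1024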
3772 const int nth = std::max(1, std::min(1024, dim/2));
3773
3774 ggml_metal_encoder_set_pipeline(enc, pipeline);
3775 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3776 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3777 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3778
3779 ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
3780
3781 return 1;
3782}
3783
3784int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3785 ggml_tensor * op = ctx->node(idx);
3786
3787 ggml_metal_library_t lib = ctx->lib;
3788 ggml_metal_encoder_t enc = ctx->enc;
3789
3790 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3791 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3792 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3793 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3794
3795 ggml_metal_kargs_argmax args = {
3796 /*.ne00 = */ ne00,
3797 /*.nb01 = */ nb01,
3798 };
3799
3800 auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3801
3802 const int64_t nrows = ggml_nrows(op->src[0]);
3803
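    // grow the threadgroup from one SIMD group until it spans the row,
    // while keeping the total thread count across all rows below 256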
3804 int nth = 32; // SIMD width
3805 while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
3806 nth *= 2;
3807 }
3808
3809 const size_t smem = pipeline.smem;
3810
3811 ggml_metal_encoder_set_pipeline(enc, pipeline);
3812 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3813 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3814 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3815
3816 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3817
3818 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3819
3820 return 1;
3821}
3822
3823int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
3824 ggml_tensor * op = ctx->node(idx);
3825
3826 ggml_metal_library_t lib = ctx->lib;
3827 ggml_metal_encoder_t enc = ctx->enc;
3828
3829 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3830
3831 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3832 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3833 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3834 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3835
3836 auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3837
    // bitonic sort requires the number of elements to be a power of 2
3839 int nth = 1;
3840 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3841 nth *= 2;
3842 }
3843
    const int npr = (ne00 + nth - 1)/nth; // blocks per row
3845
    // Metal requires the threadgroup memory size to be a multiple of 16 bytes
    // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3848 const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3849
3850 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3851 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3852
3853 ggml_metal_buffer_id bid_tmp = bid_dst;
3854 bid_tmp.offs += ggml_nbytes(op);
3855
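    // each row is first sorted in blocks of nth elements and the sorted runs are then merged pairwise;
    // the merge passes ping-pong between dst and the scratch region placed right after it, so when the
    // number of passes (ceil(log2(npr))) is odd, start swapped so that the final pass writes into dst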
3856 if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3857 std::swap(bid_dst, bid_tmp);
3858 }
3859
3860 ggml_metal_kargs_argsort args = {
3861 /*.ne00 =*/ ne00,
3862 /*.ne01 =*/ ne01,
3863 /*.ne02 =*/ ne02,
3864 /*.ne03 =*/ ne03,
3865 /*.nb00 =*/ nb00,
3866 /*.nb01 =*/ nb01,
3867 /*.nb02 =*/ nb02,
3868 /*.nb03 =*/ nb03,
3869 /*.ne0 =*/ ne0,
3870 /*.ne1 =*/ ne1,
3871 /*.ne2 =*/ ne2,
3872 /*.ne3 =*/ ne3,
3873 /*.top_k =*/ nth,
3874 };
3875
3876 ggml_metal_encoder_set_pipeline(enc, pipeline);
3877 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3878 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3879 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3880
3881 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3882
3883 ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3884
3885 auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3886
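    // pairwise-merge sorted runs of length len until the whole row is sorted;
    // each pass reads the previous pass' output, hence the concurrency reset in between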
3887 int len = nth;
3888
3889 while (len < ne00) {
3890 ggml_metal_op_concurrency_reset(ctx);
3891
3892 ggml_metal_kargs_argsort_merge args_merge = {
3893 /*.ne00 =*/ ne00,
3894 /*.ne01 =*/ ne01,
3895 /*.ne02 =*/ ne02,
3896 /*.ne03 =*/ ne03,
3897 /*.nb00 =*/ nb00,
3898 /*.nb01 =*/ nb01,
3899 /*.nb02 =*/ nb02,
3900 /*.nb03 =*/ nb03,
3901 /*.ne0 =*/ ne0,
3902 /*.ne1 =*/ ne1,
3903 /*.ne2 =*/ ne2,
3904 /*.ne3 =*/ ne3,
3905 /*.top_k =*/ ne00,
3906 /*.len =*/ len,
3907 };
3908
3909 // merges per row
3910 const int nm = (ne00 + 2*len - 1) / (2*len);
3911
3912 const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3913
3914 ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3915 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3916 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3917 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3918 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3919
3920 ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3921
3922 std::swap(bid_dst, bid_tmp);
3923
3924 len <<= 1;
3925 }
3926
3927 return 1;
3928}
3929
3930int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3931 ggml_tensor * op = ctx->node(idx);
3932
3933 ggml_metal_library_t lib = ctx->lib;
3934 ggml_metal_encoder_t enc = ctx->enc;
3935
3936 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3937
3938 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3939 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3940 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3941 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3942
3943 auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
3944
    // bitonic sort requires the number of elements to be a power of 2
3946 int nth = 1;
3947 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3948 nth *= 2;
3949 }
3950
3951 // blocks per row
3952 const int npr = (ne00 + nth - 1)/nth;
3953
    // Metal requires the threadgroup memory size to be a multiple of 16 bytes
    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3955
3956 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3957 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3958
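    // place the ping-pong scratch one full row-length index buffer past the start of dst:
    // intermediate merge passes can hold up to ne00 indices per row, not just the final top_k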
3959 ggml_metal_buffer_id bid_tmp = bid_dst;
3960 bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3961
3962 if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3963 std::swap(bid_dst, bid_tmp);
3964 }
3965
3966 const int top_k = ne0;
3967
3968 ggml_metal_kargs_argsort args = {
3969 /*.ne00 =*/ ne00,
3970 /*.ne01 =*/ ne01,
3971 /*.ne02 =*/ ne02,
3972 /*.ne03 =*/ ne03,
3973 /*.nb00 =*/ nb00,
3974 /*.nb01 =*/ nb01,
3975 /*.nb02 =*/ nb02,
3976 /*.nb03 =*/ nb03,
3977 /*.ne0 =*/ ne0,
3978 /*.ne1 =*/ ne1,
3979 /*.ne2 =*/ ne2,
3980 /*.ne3 =*/ ne3,
3981 /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
3982 };
3983
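    // with several blocks per row, each block contributes at most top_k candidates
    // (the last, possibly partial, block may contribute fewer), so shrink the row length used for merging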
3984 if (npr > 1) {
3985 args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3986 }
3987
3988 ggml_metal_encoder_set_pipeline(enc, pipeline);
3989 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3990 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3991 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3992
3993 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3994
3995 ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3996
3997 auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
3998
3999 int len = args.top_k;
4000
4001 while (len < args.ne0) {
4002 ggml_metal_op_concurrency_reset(ctx);
4003
4004 // merges per row
4005 const int nm = (args.ne0 + 2*len - 1) / (2*len);
4006
4007 const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
4008
4009 ggml_metal_kargs_argsort_merge args_merge = {
4010 /*.ne00 =*/ ne00,
4011 /*.ne01 =*/ ne01,
4012 /*.ne02 =*/ ne02,
4013 /*.ne03 =*/ ne03,
4014 /*.nb00 =*/ nb00,
4015 /*.nb01 =*/ nb01,
4016 /*.nb02 =*/ nb02,
4017 /*.nb03 =*/ nb03,
4018 /*.ne0 =*/ args.ne0,
4019 /*.ne1 =*/ ne1,
4020 /*.ne2 =*/ ne2,
4021 /*.ne3 =*/ ne3,
4022 /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
4023 /*.len =*/ len,
4024 };
4025
4026 ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
4027 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
4028 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4029 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
4030 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
4031
4032 ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
4033
4034 std::swap(bid_dst, bid_tmp);
4035
4036 len <<= 1;
4037 }
4038
4039 return 1;
4040}
4041
4042int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
4043 ggml_tensor * op = ctx->node(idx);
4044
4045 ggml_metal_library_t lib = ctx->lib;
4046 ggml_metal_encoder_t enc = ctx->enc;
4047
4048 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4049 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4050 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4051 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4052
4053 ggml_metal_kargs_tri args = {
4054 /*.ne00 =*/ ne00,
4055 /*.ne01 =*/ ne01,
4056 /*.ne02 =*/ ne02,
4057 /*.ne03 =*/ ne03,
4058 /*.nb00 =*/ nb00,
4059 /*.nb01 =*/ nb01,
4060 /*.nb02 =*/ nb02,
4061 /*.nb03 =*/ nb03,
4062 /*.ne0 =*/ ne0,
4063 /*.ne1 =*/ ne1,
4064 /*.ne2 =*/ ne2,
4065 /*.ne3 =*/ ne3,
4066 /*.nb0 =*/ nb0,
4067 /*.nb1 =*/ nb1,
4068 /*.nb2 =*/ nb2,
4069 /*.nb3 =*/ nb3,
4070 };
4071
4072 auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
4073
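    // one threadgroup per row; grow it from one SIMD group up to the row length or the pipeline limit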
4074 int nth = 32; // SIMD width
4075
4076 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4077 nth *= 2;
4078 }
4079
4080 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4081 nth = std::min(nth, ne00);
4082
4083 ggml_metal_encoder_set_pipeline(enc, pipeline);
4084 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
4085 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4086 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
4087
4088 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4089
4090 return 1;
4091}
4092
4093int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
4094 ggml_tensor * op = ctx->node(idx);
4095
4096 ggml_metal_library_t lib = ctx->lib;
4097 ggml_metal_encoder_t enc = ctx->enc;
4098
4099 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4100 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4101 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4102 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4103
4104 auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4105
4106 const int64_t np = ggml_nelements(op->src[0]);
4107 ggml_metal_kargs_opt_step_adamw args = {
4108 /*.np =*/ np,
4109 };
4110
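    // slot 0 carries the kernel arguments; the five OPT_STEP_ADAMW source tensors
    // (weights, gradients and optimizer state) follow in src order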
4111 int ida = 0;
4112
4113 ggml_metal_encoder_set_pipeline(enc, pipeline);
4114 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4115 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4116 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4117 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4118 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
4119 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
4120
4121 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4122 const int64_t n = (np + nth - 1) / nth;
4123
4124 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4125
4126 return 1;
4127}
4128
4129int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
4130 ggml_tensor * op = ctx->node(idx);
4131
4132 ggml_metal_library_t lib = ctx->lib;
4133 ggml_metal_encoder_t enc = ctx->enc;
4134
4135 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4136 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4137 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4138 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4139
4140 auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4141
4142 const int64_t np = ggml_nelements(op->src[0]);
4143 ggml_metal_kargs_opt_step_sgd args = {
4144 /*.np =*/ np,
4145 };
4146
4147 int ida = 0;
4148
4149 ggml_metal_encoder_set_pipeline(enc, pipeline);
4150 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4151 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4152 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4153 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4154
4155 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4156 const int64_t n = (np + nth - 1) / nth;
4157
4158 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4159
4160 return 1;
4161}
4162
4163int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
4164 ggml_tensor * op = ctx->node(idx);
4165
4166 ggml_metal_library_t lib = ctx->lib;
4167 ggml_metal_encoder_t enc = ctx->enc;
4168
4169 GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
4170 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4171 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4172
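    // two passes: first zero the result, then accumulate the equality count into it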
4173 {
4174 ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4175
4176 auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
4177
4178 ggml_metal_encoder_set_pipeline(enc, pipeline);
4179 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4180 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
4181
4182 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4183 }
4184
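    // the count kernel must observe the zeroed result, so force a barrier between the two dispatches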
4185 ggml_metal_op_concurrency_reset(ctx);
4186
4187 {
4188 ggml_metal_kargs_count_equal args = {
4189 /*.ne00 =*/ ne00,
4190 /*.ne01 =*/ ne01,
4191 /*.ne02 =*/ ne02,
4192 /*.ne03 =*/ ne03,
4193 /*.nb00 =*/ nb00,
4194 /*.nb01 =*/ nb01,
4195 /*.nb02 =*/ nb02,
4196 /*.nb03 =*/ nb03,
4197 /*.nb10 =*/ nb10,
4198 /*.nb11 =*/ nb11,
4199 /*.nb12 =*/ nb12,
4200 /*.nb13 =*/ nb13,
4201 };
4202
4203 auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
4204
4205 const size_t smem = pipeline.smem;
4206
4207 const int nth = 32*pipeline.nsg;
4208
4209 GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4210
4211 ggml_metal_encoder_set_pipeline(enc, pipeline);
4212 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4213 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4214 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
4215 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
4216
4217 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4218 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4219 }
4220
4221 return 1;
4222}