1#include "ggml-metal-device.h"
2
3#include "ggml-metal-impl.h"
4
5#include "ggml-impl.h"
6
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
11
12struct ggml_metal_device_deleter {
13 void operator()(ggml_metal_device_t ctx) {
14 ggml_metal_device_free(ctx);
15 }
16};
17
18typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
19
20ggml_metal_device_t ggml_metal_device_get(int device) {
21 static std::vector<ggml_metal_device_ptr> devs;
22
23 devs.emplace_back(ggml_metal_device_init(device));
24
25 return devs.back().get();
26}
27
28struct ggml_metal_pipelines {
29 std::unordered_map<std::string, ggml_metal_pipeline_t> data;
30};
31
32ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
33 ggml_metal_pipelines_t res = new ggml_metal_pipelines();
34
35 return res;
36}
37
38void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
39 if (!ppls) {
40 return;
41 }
42
43 for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
44 ggml_metal_pipeline_free(it->second);
45 }
46
47 delete ppls;
48}
49
50void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {
51 ppls->data[name] = pipeline;
52}
53
54ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
55 if (ppls->data.find(name) == ppls->data.end()) {
56 return nullptr;
57 }
58
59 return ppls->data[name];
60}
61
62struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
63 char base[256];
64 char name[256];
65
66 const char * op_str = "undefined";
67 switch (op) {
68 case GGML_OP_ADD_ID: op_str = "add_id"; break;
69 case GGML_OP_CONCAT: op_str = "concat"; break;
70 default: GGML_ABORT("fatal error");
71 };
72
73 snprintf(base, 256, "kernel_%s", op_str);
74 snprintf(name, 256, "%s", base);
75
76 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
77 if (!res.pipeline) {
78 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
79 }
80
81 return res;
82}
83
84ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
85 char base[256];
86 char name[256];
87
88 snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
89 snprintf(name, 256, "%s", base);
90
91 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
92 if (!res.pipeline) {
93 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
94 }
95
96 return res;
97}
98
99ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
100 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
101 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
102
103 const char * pool_str = "undefined";
104 switch (op_pool) {
105 case GGML_OP_POOL_AVG: pool_str = "avg"; break;
106 case GGML_OP_POOL_MAX: pool_str = "max"; break;
107 default: GGML_ASSERT(false && "not implemented");
108 };
109
110 char base[256];
111 char name[256];
112
113 snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
114 snprintf(name, sizeof(name), "%s", base);
115
116 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
117 if (!res.pipeline) {
118 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
119 }
120
121 return res;
122}
123
124ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
125 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
126 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
127
128 const char * pool_str = "undefined";
129 switch (op_pool) {
130 case GGML_OP_POOL_AVG: pool_str = "avg"; break;
131 case GGML_OP_POOL_MAX: pool_str = "max"; break;
132 default: GGML_ASSERT(false && "not implemented");
133 };
134
135 char base[256];
136 char name[256];
137
138 snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
139 snprintf(name, 256, "%s", base);
140
141 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
142 if (!res.pipeline) {
143 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
144 }
145
146 return res;
147}
148
149ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
150 char base[256];
151 char name[256];
152
153 snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
154 snprintf(name, 256, "%s", base);
155
156 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
157 if (!res.pipeline) {
158 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
159 }
160
161 return res;
162}
163
164ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
165 char base[256];
166 char name[256];
167
168 snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
169 snprintf(name, 256, "%s", base);
170
171 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
172 if (!res.pipeline) {
173 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
174 }
175
176 return res;
177}
178
179ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
180 char base[256];
181 char name[256];
182
183 const int n = op->src[0]->ne[0];
184
185 snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
186 snprintf(name, 256, "%s_n=%d", base, n);
187
188 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
189 if (!res.pipeline) {
190 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
191 }
192
193 res.nsg = 1;
194 res.smem = 0;
195
196 return res;
197}
198
199ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
200 char base[256];
201 char name[256];
202
203 snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
204 snprintf(name, 256, "%s", base);
205
206 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
207 if (!res.pipeline) {
208 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
209 }
210
211 return res;
212}
213
214ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
215 char base[256];
216 char name[256];
217
218 int op_num = -1;
219
220 switch (op->op) {
221 case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
222 case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
223 case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
224 case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
225 case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
226 case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
227 case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
228 case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
229 case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
230 case GGML_OP_UNARY:
231 switch (ggml_get_unary_op(op)) {
232 case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
233 case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
234 case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
235 case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
236 case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
237 case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
238 case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
239 case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
240 case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
241 case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
242 case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
243 case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
244 case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
245 case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
246 case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
247 case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
248 case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
249 default: GGML_ABORT("fatal error");
250 } break;
251 default: GGML_ABORT("fatal error");
252 };
253
254 const char * t0_str = ggml_type_name(op->src[0]->type);
255 const char * t_str = ggml_type_name(op->type);
256
257 const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
258 const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
259
260 snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
261 snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
262
263 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
264 if (!res.pipeline) {
265 ggml_metal_cv_t cv = ggml_metal_cv_init();
266
267 ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
268 ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
269
270 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
271
272 ggml_metal_cv_free(cv);
273 }
274
275 res.c4 = is_c4;
276 res.cnt = is_cnt;
277
278 return res;
279}
280
281ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
282 GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
283
284 char base[256];
285 char name[256];
286
287 const char * op_str = "undefined";
288 switch (op->op) {
289 case GGML_OP_GLU:
290 switch (ggml_get_glu_op(op)) {
291 case GGML_GLU_OP_REGLU: op_str = "reglu"; break;
292 case GGML_GLU_OP_GEGLU: op_str = "geglu"; break;
293 case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break;
294 case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break;
295 case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break;
296 case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break;
297 default: GGML_ABORT("fatal error");
298 } break;
299 default: GGML_ABORT("fatal error");
300 };
301
302 snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
303 snprintf(name, 256, "%s", base);
304
305 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
306 if (!res.pipeline) {
307 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
308 }
309
310 return res;
311}
312
313ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
314 assert(op->op == GGML_OP_SUM);
315
316 char base[256];
317 char name[256];
318
319 snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
320 snprintf(name, 256, "%s", base);
321
322 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
323 if (!res.pipeline) {
324 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
325 }
326
327 return res;
328}
329
330ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
331 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
332
333 char base[256];
334 char name[256];
335
336 const char * op_str = "undefined";
337 switch (op->op) {
338 case GGML_OP_SUM_ROWS:
339 op_str = "sum_rows"; break;
340 case GGML_OP_MEAN:
341 op_str = "mean"; break;
342 default: GGML_ABORT("fatal error");
343 };
344
345 snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
346
347 snprintf(name, 256, "%s", base);
348
349 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
350 if (!res.pipeline) {
351 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
352 }
353
354 res.smem = 32*sizeof(float);
355
356 return res;
357}
358
359ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
360 GGML_ASSERT(op->op == GGML_OP_CUMSUM);
361
362 char base[256];
363 char name[256];
364
365 snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
366 snprintf(name, 256, "%s", base);
367
368 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
369 if (!res.pipeline) {
370 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
371 }
372
373 return res;
374}
375
376ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
377 GGML_ASSERT(op->op == GGML_OP_CUMSUM);
378
379 char base[256];
380 char name[256];
381
382 snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
383 snprintf(name, 256, "%s", base);
384
385 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
386 if (!res.pipeline) {
387 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
388 }
389
390 return res;
391}
392
393ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
394 GGML_ASSERT(op->op == GGML_OP_TRI);
395 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
396
397 char base[256];
398 char name[256];
399
400 const char * op_str = "tri";
401 const int ttype = op->op_params[0];
402
403 snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
404
405 snprintf(name, 256, "%s", base);
406
407 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
408 if (!res.pipeline) {
409 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
410 }
411
412 return res;
413}
414
415ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
416 GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
417
418 char base[256];
419 char name[256];
420
421 const char * suffix = "";
422
423 if (op->src[0]->ne[0] % 4 == 0) {
424 suffix = "_4";
425 }
426
427 const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32;
428
429 snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
430 snprintf(name, 256, "%s", base);
431
432 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
433 if (!res.pipeline) {
434 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
435 }
436
437 res.smem = 32*sizeof(float);
438
439 return res;
440}
441
442ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
443 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
444 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
445
446 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
447 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
448
449 char base[256];
450 char name[256];
451
452 const char * suffix = "";
453
454 if (op->src[1]->ne[0] % 4 == 0) {
455 suffix = "_4";
456 }
457
458 snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
459 snprintf(name, 256, "%s", base);
460
461 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
462 if (!res.pipeline) {
463 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
464 }
465
466 return res;
467}
468
469ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
470 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
471 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472
473 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
474 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
475
476 char base[256];
477 char name[256];
478
479 const char * suffix = "";
480 if (op->src[1]->ne[0] % 4 == 0) {
481 suffix = "_4";
482 }
483
484 snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
485 snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
486
487 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
488 if (!res.pipeline) {
489 ggml_metal_cv_t cv = ggml_metal_cv_init();
490
491 ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
492
493 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
494
495 ggml_metal_cv_free(cv);
496 }
497
498 return res;
499}
500
501ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
502 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
503
504 char base[256];
505 char name[256];
506
507 const int nsg = (ne00 + 31)/32;
508
509 snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
510 snprintf(name, 256, "%s_nsg=%d", base, nsg);
511
512 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
513 if (!res.pipeline) {
514 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
515 }
516
517 // Shared memory layout:
518 // - sgptg * NW floats for partial sums (nsg * 32)
519 // - sgptg floats for shared_x_dt (nsg)
520 // - sgptg floats for shared_dA (nsg)
521 // Total: nsg * (32 + 2) floats
522 res.smem = (32 + 2)*sizeof(float)*nsg;
523
524 return res;
525}
526
527ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
528 char base[256];
529 char name[256];
530
531 const int64_t C = op->ne[0];
532 const int64_t H = op->src[0]->ne[1];
533
534 switch (op->op) {
535 case GGML_OP_RWKV_WKV6:
536 {
537 GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
538 GGML_ASSERT(C % H == 0);
539 GGML_ASSERT(C / H == 64);
540
541 snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type));
542 } break;
543 case GGML_OP_RWKV_WKV7:
544 {
545 GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32);
546 GGML_ASSERT(C % H == 0);
547 GGML_ASSERT(C / H == 64);
548
549 snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type));
550 } break;
551 default:
552 GGML_ABORT("fatal error");
553 }
554
555 snprintf(name, 256, "%s", base);
556
557 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
558 if (!res.pipeline) {
559 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
560 }
561
562 return res;
563}
564
565ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
566 char base[256];
567 char name[256];
568
569 const int nsg = 8;
570 const int n = op->src[1]->ne[1];
571 const int k = op->src[1]->ne[0];
572
573 snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
574 snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
575
576 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
577 if (!res.pipeline) {
578 ggml_metal_cv_t cv = ggml_metal_cv_init();
579
580 ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
581 ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1);
582 ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2);
583
584 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
585
586 ggml_metal_cv_free(cv);
587 }
588
589 res.nsg = nsg;
590 res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
591
592 return res;
593}
594
595ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
596 char base[256];
597 char name[256];
598
599 snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
600 snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
601
602 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
603 if (!res.pipeline) {
604 ggml_metal_cv_t cv = ggml_metal_cv_init();
605
606 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
607 ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
608
609 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
610
611 ggml_metal_cv_free(cv);
612 }
613
614 return res;
615}
616
617ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
618 char base[256];
619 char name[256];
620
621 const ggml_type tsrc0 = op->src[0]->type;
622 const ggml_type tsrc1 = op->src[1]->type;
623
624 const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
625 const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
626
627 snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
628 snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
629
630 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
631 if (!res.pipeline) {
632 ggml_metal_cv_t cv = ggml_metal_cv_init();
633
634 ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
635 ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
636
637 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
638
639 ggml_metal_cv_free(cv);
640 }
641
642 // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
643 res.smem = bc_out ? 8192 : 4096 + 2048;
644
645 return res;
646}
647
648ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
649 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
650 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
651
652 char base[256];
653 char name[256];
654
655 int nsg = 0; // number of simdgroups
656 int nr0 = 0; // number of src0 rows per simdgroup
657 int nr1 = 1; // number of src1 rows per threadgroup
658
659 size_t smem = 0; // shared memory
660
661 const ggml_type tsrc0 = op->src[0]->type;
662 const ggml_type tsrc1 = op->src[1]->type;
663
664 const char * suffix = "";
665
666 // use custom matrix x vector kernel
667 switch (tsrc0) {
668 case GGML_TYPE_F32:
669 case GGML_TYPE_F16:
670 case GGML_TYPE_BF16:
671 {
672 if (ne00 < 32) {
673 nsg = 1;
674 nr0 = 32;
675 nr1 = 1;
676 suffix = "_short";
677 } else {
678 nsg = std::min(4, (ne00 + 127) / 128);
679 nr0 = 2;
680 nr1 = 1;
681 smem = 32*sizeof(float)*nr0;
682 suffix = ne00 % 4 == 0 ? "_4" : "";
683 }
684 } break;
685 case GGML_TYPE_Q4_0:
686 {
687 nsg = N_SG_Q4_0;
688 nr0 = N_R0_Q4_0;
689 } break;
690 case GGML_TYPE_Q4_1:
691 {
692 nsg = N_SG_Q4_1;
693 nr0 = N_R0_Q4_1;
694 } break;
695 case GGML_TYPE_Q5_0:
696 {
697 nsg = N_SG_Q5_0;
698 nr0 = N_R0_Q5_0;
699 } break;
700 case GGML_TYPE_Q5_1:
701 {
702 nsg = N_SG_Q5_1;
703 nr0 = N_R0_Q5_1;
704 } break;
705 case GGML_TYPE_Q8_0:
706 {
707 nsg = N_SG_Q8_0;
708 nr0 = N_R0_Q8_0;
709 smem = 32*sizeof(float)*N_R0_Q8_0;
710 } break;
711 case GGML_TYPE_MXFP4:
712 {
713 nsg = N_SG_MXFP4;
714 nr0 = N_R0_MXFP4;
715 smem = 32*sizeof(float);
716 } break;
717 case GGML_TYPE_Q2_K:
718 {
719 nsg = N_SG_Q2_K;
720 nr0 = N_R0_Q2_K;
721 } break;
722 case GGML_TYPE_Q3_K:
723 {
724 nsg = N_SG_Q3_K;
725 nr0 = N_R0_Q3_K;
726 } break;
727 case GGML_TYPE_Q4_K:
728 {
729 nsg = N_SG_Q4_K;
730 nr0 = N_R0_Q4_K;
731 } break;
732 case GGML_TYPE_Q5_K:
733 {
734 nsg = N_SG_Q5_K;
735 nr0 = N_R0_Q5_K;
736 } break;
737 case GGML_TYPE_Q6_K:
738 {
739 nsg = N_SG_Q6_K;
740 nr0 = N_R0_Q6_K;
741 } break;
742 case GGML_TYPE_IQ2_XXS:
743 {
744 nsg = N_SG_IQ2_XXS;
745 nr0 = N_R0_IQ2_XXS;
746 smem = 256*8+128;
747 } break;
748 case GGML_TYPE_IQ2_XS:
749 {
750 nsg = N_SG_IQ2_XS;
751 nr0 = N_R0_IQ2_XS;
752 smem = 512*8+128;
753 } break;
754 case GGML_TYPE_IQ3_XXS:
755 {
756 nsg = N_SG_IQ3_XXS;
757 nr0 = N_R0_IQ3_XXS;
758 smem = 256*4+128;
759 } break;
760 case GGML_TYPE_IQ3_S:
761 {
762 nsg = N_SG_IQ3_S;
763 nr0 = N_R0_IQ3_S;
764 smem = 512*4;
765 } break;
766 case GGML_TYPE_IQ2_S:
767 {
768 nsg = N_SG_IQ2_S;
769 nr0 = N_R0_IQ2_S;
770 } break;
771 case GGML_TYPE_IQ1_S:
772 {
773 nsg = N_SG_IQ1_S;
774 nr0 = N_R0_IQ1_S;
775 } break;
776 case GGML_TYPE_IQ1_M:
777 {
778 nsg = N_SG_IQ1_M;
779 nr0 = N_R0_IQ1_M;
780 } break;
781 case GGML_TYPE_IQ4_NL:
782 {
783 nsg = N_SG_IQ4_NL;
784 nr0 = N_R0_IQ4_NL;
785 smem = 32*sizeof(float);
786 } break;
787 case GGML_TYPE_IQ4_XS:
788 {
789 nsg = N_SG_IQ4_XS;
790 nr0 = N_R0_IQ4_XS;
791 smem = 32*sizeof(float);
792 } break;
793 default:
794 {
795 GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
796 GGML_ABORT("not implemented");
797 }
798 };
799
800 snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
801 snprintf(name, 256, "%s_nsg=%d", base, nsg);
802
803 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
804 if (!res.pipeline) {
805 ggml_metal_cv_t cv = ggml_metal_cv_init();
806
807 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
808
809 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
810
811 ggml_metal_cv_free(cv);
812 }
813
814 res.nr0 = nr0;
815 res.nr1 = nr1;
816 res.nsg = nsg;
817 res.smem = smem;
818
819 return res;
820}
821
822ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
823 char base[256];
824 char name[256];
825
826 snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
827 snprintf(name, 256, "%s_ne02=%d", base, ne02);
828
829 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
830 if (!res.pipeline) {
831 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
832 }
833
834 res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
835
836 return res;
837}
838
839ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
840 char base[256];
841 char name[256];
842
843 const ggml_type tsrc0 = op->src[0]->type;
844 const ggml_type tsrc1 = op->src[1]->type;
845
846 const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
847
848 snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
849 snprintf(name, 256, "%s_bci=%d", base, bc_inp);
850
851 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
852 if (!res.pipeline) {
853 ggml_metal_cv_t cv = ggml_metal_cv_init();
854
855 ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
856
857 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
858
859 ggml_metal_cv_free(cv);
860 }
861
862 res.smem = 8192;
863
864 return res;
865}
866
867ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
868 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
869 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
870
871 char base[256];
872 char name[256];
873
874 int nsg = 0; // number of simdgroups
875 int nr0 = 0; // number of src0 rows per simdgroup
876 int nr1 = 1; // number of src1 rows per threadgroup
877
878 size_t smem = 0; // shared memory
879
880 const ggml_type tsrc0 = op->src[0]->type;
881 const ggml_type tsrc1 = op->src[1]->type;
882
883 const char * suffix = "";
884
885 // use custom matrix x vector kernel
886 switch (tsrc0) {
887 case GGML_TYPE_F32:
888 case GGML_TYPE_F16:
889 case GGML_TYPE_BF16:
890 {
891 nsg = std::min(4, (ne00 + 127) / 128);
892 nr0 = 2;
893 nr1 = 1;
894 smem = 32*sizeof(float)*nr0;
895 suffix = ne00 % 4 == 0 ? "_4" : "";
896 } break;
897 case GGML_TYPE_Q4_0:
898 {
899 nsg = N_SG_Q4_0;
900 nr0 = N_R0_Q4_0;
901 } break;
902 case GGML_TYPE_Q4_1:
903 {
904 nsg = N_SG_Q4_1;
905 nr0 = N_R0_Q4_1;
906 } break;
907 case GGML_TYPE_Q5_0:
908 {
909 nsg = N_SG_Q5_0;
910 nr0 = N_R0_Q5_0;
911 } break;
912 case GGML_TYPE_Q5_1:
913 {
914 nsg = N_SG_Q5_1;
915 nr0 = N_R0_Q5_1;
916 } break;
917 case GGML_TYPE_Q8_0:
918 {
919 nsg = N_SG_Q8_0;
920 nr0 = N_R0_Q8_0;
921 smem = 32*sizeof(float)*N_R0_Q8_0;
922 } break;
923 case GGML_TYPE_MXFP4:
924 {
925 nsg = N_SG_MXFP4;
926 nr0 = N_R0_MXFP4;
927 smem = 32*sizeof(float);
928 } break;
929 case GGML_TYPE_Q2_K:
930 {
931 nsg = N_SG_Q2_K;
932 nr0 = N_R0_Q2_K;
933 } break;
934 case GGML_TYPE_Q3_K:
935 {
936 nsg = N_SG_Q3_K;
937 nr0 = N_R0_Q3_K;
938 } break;
939 case GGML_TYPE_Q4_K:
940 {
941 nsg = N_SG_Q4_K;
942 nr0 = N_R0_Q4_K;
943 } break;
944 case GGML_TYPE_Q5_K:
945 {
946 nsg = N_SG_Q5_K;
947 nr0 = N_R0_Q5_K;
948 } break;
949 case GGML_TYPE_Q6_K:
950 {
951 nsg = N_SG_Q6_K;
952 nr0 = N_R0_Q6_K;
953 } break;
954 case GGML_TYPE_IQ2_XXS:
955 {
956 nsg = N_SG_IQ2_XXS;
957 nr0 = N_R0_IQ2_XXS;
958 smem = 256*8+128;
959 } break;
960 case GGML_TYPE_IQ2_XS:
961 {
962 nsg = N_SG_IQ2_XS;
963 nr0 = N_R0_IQ2_XS;
964 smem = 512*8+128;
965 } break;
966 case GGML_TYPE_IQ3_XXS:
967 {
968 nsg = N_SG_IQ3_XXS;
969 nr0 = N_R0_IQ3_XXS;
970 smem = 256*4+128;
971 } break;
972 case GGML_TYPE_IQ3_S:
973 {
974 nsg = N_SG_IQ3_S;
975 nr0 = N_R0_IQ3_S;
976 smem = 512*4;
977 } break;
978 case GGML_TYPE_IQ2_S:
979 {
980 nsg = N_SG_IQ2_S;
981 nr0 = N_R0_IQ2_S;
982 } break;
983 case GGML_TYPE_IQ1_S:
984 {
985 nsg = N_SG_IQ1_S;
986 nr0 = N_R0_IQ1_S;
987 } break;
988 case GGML_TYPE_IQ1_M:
989 {
990 nsg = N_SG_IQ1_M;
991 nr0 = N_R0_IQ1_M;
992 } break;
993 case GGML_TYPE_IQ4_NL:
994 {
995 nsg = N_SG_IQ4_NL;
996 nr0 = N_R0_IQ4_NL;
997 smem = 32*sizeof(float);
998 } break;
999 case GGML_TYPE_IQ4_XS:
1000 {
1001 nsg = N_SG_IQ4_XS;
1002 nr0 = N_R0_IQ4_XS;
1003 smem = 32*sizeof(float);
1004 } break;
1005 default:
1006 {
1007 GGML_LOG_ERROR("Asserting on type %d\n", (int)op->src[2]->type);
1008 GGML_ABORT("not implemented");
1009 }
1010 };
1011
1012 snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
1013 snprintf(name, 256, "%s_nsg=%d", base, nsg);
1014
1015 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1016 if (!res.pipeline) {
1017 ggml_metal_cv_t cv = ggml_metal_cv_init();
1018
1019 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
1020
1021 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1022
1023 ggml_metal_cv_free(cv);
1024 }
1025
1026 res.nr0 = nr0;
1027 res.nr1 = nr1;
1028 res.nsg = nsg;
1029 res.smem = smem;
1030
1031 return res;
1032}
1033
1034ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
1035 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
1036 GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
1037 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
1038
1039 char base[256];
1040 char name[256];
1041
1042 snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
1043 snprintf(name, 256, "%s", base);
1044
1045 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1046 if (!res.pipeline) {
1047 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1048 }
1049
1050 res.smem = 32*(sizeof(float) + sizeof(int32_t));
1051
1052 return res;
1053}
1054
1055ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
1056 assert(op->op == GGML_OP_ARGSORT);
1057
1058 char base[256];
1059 char name[256];
1060
1061 ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1062
1063 const char * order_str = "undefined";
1064 switch (order) {
1065 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1066 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1067 default: GGML_ABORT("fatal error");
1068 };
1069
1070 snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1071 snprintf(name, 256, "%s", base);
1072
1073 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1074 if (!res.pipeline) {
1075 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1076 }
1077
1078 return res;
1079}
1080
1081ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1082 assert(op->op == GGML_OP_ARGSORT);
1083
1084 char base[256];
1085 char name[256];
1086
1087 ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1088
1089 const char * order_str = "undefined";
1090 switch (order) {
1091 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1092 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1093 default: GGML_ABORT("fatal error");
1094 };
1095
1096 snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1097 snprintf(name, 256, "%s", base);
1098
1099 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1100 if (!res.pipeline) {
1101 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1102 }
1103
1104 return res;
1105}
1106
1107// note: reuse the argsort kernel for top_k
1108ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1109 assert(op->op == GGML_OP_TOP_K);
1110
1111 char base[256];
1112 char name[256];
1113
1114 // note: the top_k kernel is always descending order
1115 ggml_sort_order order = GGML_SORT_ORDER_DESC;
1116
1117 const char * order_str = "undefined";
1118 switch (order) {
1119 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1120 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1121 default: GGML_ABORT("fatal error");
1122 };
1123
1124 snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1125 snprintf(name, 256, "%s", base);
1126
1127 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1128 if (!res.pipeline) {
1129 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1130 }
1131
1132 return res;
1133}
1134
1135ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1136 assert(op->op == GGML_OP_TOP_K);
1137
1138 char base[256];
1139 char name[256];
1140
1141 ggml_sort_order order = GGML_SORT_ORDER_DESC;
1142
1143 const char * order_str = "undefined";
1144 switch (order) {
1145 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1146 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1147 default: GGML_ABORT("fatal error");
1148 };
1149
1150 snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1151 snprintf(name, 256, "%s", base);
1152
1153 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1154 if (!res.pipeline) {
1155 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1156 }
1157
1158 return res;
1159}
1160
1161ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1162 ggml_metal_library_t lib,
1163 const struct ggml_tensor * op,
1164 bool has_mask,
1165 int32_t ncpsg) {
1166 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1167 GGML_UNUSED(op);
1168
1169 char base[256];
1170 char name[256];
1171
1172 snprintf(base, 256, "kernel_%s",
1173 "flash_attn_ext_pad");
1174
1175 snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1176 base,
1177 has_mask,
1178 ncpsg);
1179
1180 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1181 if (!res.pipeline) {
1182 ggml_metal_cv_t cv = ggml_metal_cv_init();
1183
1184 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1185 //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1186 //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1187 //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1188
1189 //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1190 //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1191 //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1192 //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1193 //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1194 ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1195
1196 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1197
1198 ggml_metal_cv_free(cv);
1199 }
1200
1201 return res;
1202}
1203
1204ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1205 ggml_metal_library_t lib,
1206 const struct ggml_tensor * op,
1207 int32_t nqptg,
1208 int32_t ncpsg) {
1209 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1210 GGML_UNUSED(op);
1211
1212 char base[256];
1213 char name[256];
1214
1215 snprintf(base, 256, "kernel_%s",
1216 "flash_attn_ext_blk");
1217
1218 snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1219 base,
1220 nqptg,
1221 ncpsg);
1222
1223 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1224 if (!res.pipeline) {
1225 ggml_metal_cv_t cv = ggml_metal_cv_init();
1226
1227 //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1228 //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1229 //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1230 //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1231
1232 //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1233 //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1234 //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1235 //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1236 ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1237 ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1238
1239 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1240
1241 ggml_metal_cv_free(cv);
1242 }
1243
1244 return res;
1245}
1246
1247ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
1248 ggml_metal_library_t lib,
1249 const ggml_tensor * op,
1250 bool has_mask,
1251 bool has_sinks,
1252 bool has_bias,
1253 bool has_scap,
1254 bool has_kvpad,
1255 int32_t nsg) {
1256 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1257
1258 char base[256];
1259 char name[256];
1260
1261 const int32_t dk = (int32_t) op->src[1]->ne[0];
1262 const int32_t dv = (int32_t) op->src[2]->ne[0];
1263
1264 const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1265 const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1266
1267 // do bounds checks for the mask?
1268 const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1269
1270 snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1271 "flash_attn_ext",
1272 ggml_type_name(op->src[1]->type),
1273 dk,
1274 dv);
1275
1276 snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
1277 base,
1278 has_mask,
1279 has_sinks,
1280 has_bias,
1281 has_scap,
1282 has_kvpad,
1283 bc_mask,
1284 ns10,
1285 ns20,
1286 nsg);
1287
1288 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1289 if (!res.pipeline) {
1290 ggml_metal_cv_t cv = ggml_metal_cv_init();
1291
1292 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1293 ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1294 ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1295 ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1296 ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1297
1298 ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1299
1300 ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1301 ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1302 ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1303
1304 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1305
1306 ggml_metal_cv_free(cv);
1307 }
1308
1309 return res;
1310}
1311
1312ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1313 ggml_metal_library_t lib,
1314 const ggml_tensor * op,
1315 bool has_mask,
1316 bool has_sinks,
1317 bool has_bias,
1318 bool has_scap,
1319 bool has_kvpad,
1320 int32_t nsg,
1321 int32_t nwg) {
1322 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1323
1324 char base[256];
1325 char name[256];
1326
1327 const int32_t dk = (int32_t) op->src[1]->ne[0];
1328 const int32_t dv = (int32_t) op->src[2]->ne[0];
1329
1330 const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1331 const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1332
1333 snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1334 "flash_attn_ext_vec",
1335 ggml_type_name(op->src[1]->type),
1336 dk,
1337 dv);
1338
1339 snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1340 base,
1341 has_mask,
1342 has_sinks,
1343 has_bias,
1344 has_scap,
1345 has_kvpad,
1346 ns10,
1347 ns20,
1348 nsg, nwg);
1349
1350 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1351 if (!res.pipeline) {
1352 ggml_metal_cv_t cv = ggml_metal_cv_init();
1353
1354 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1355 ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1356 ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1357 ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1358 ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1359
1360 ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1361 ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1362 ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1363 ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1364
1365 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1366
1367 ggml_metal_cv_free(cv);
1368 }
1369
1370 return res;
1371}
1372
1373ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1374 ggml_metal_library_t lib,
1375 const ggml_tensor * op,
1376 int32_t dv,
1377 int32_t nwg) {
1378 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1379
1380 char base[256];
1381 char name[256];
1382
1383 snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1384 snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1385
1386 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1387 if (!res.pipeline) {
1388 ggml_metal_cv_t cv = ggml_metal_cv_init();
1389
1390 ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1391 ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1392
1393 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1394
1395 ggml_metal_cv_free(cv);
1396 }
1397
1398 return res;
1399
1400 GGML_UNUSED(op);
1401}
1402
1403ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1404 char base[256];
1405 char name[256];
1406
1407 int op_num = -1;
1408
1409 switch (op->op) {
1410 case GGML_OP_ADD: op_num = 0; break;
1411 case GGML_OP_SUB: op_num = 1; break;
1412 case GGML_OP_MUL: op_num = 2; break;
1413 case GGML_OP_DIV: op_num = 3; break;
1414 default: GGML_ABORT("fatal error");
1415 };
1416
1417 const char * t0_str = ggml_type_name(op->src[0]->type);
1418 const char * t1_str = ggml_type_name(op->src[1]->type);
1419 const char * t_str = ggml_type_name(op->type);
1420
1421 const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
1422
1423 const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
1424
1425 snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
1426 snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
1427
1428 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1429 if (!res.pipeline) {
1430 ggml_metal_cv_t cv = ggml_metal_cv_init();
1431
1432 ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1433 ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
1434 ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
1435
1436 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1437
1438 ggml_metal_cv_free(cv);
1439 }
1440
1441 res.c4 = is_c4;
1442 res.cnt = is_rb;
1443
1444 return res;
1445}
1446
1447ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
1448 char base[256];
1449 char name[256];
1450
1451 int op_num = -1;
1452
1453 switch (op) {
1454 case GGML_OP_ADD: op_num = 0; break;
1455 case GGML_OP_SUB: op_num = 1; break;
1456 case GGML_OP_MUL: op_num = 2; break;
1457 case GGML_OP_DIV: op_num = 3; break;
1458 default: GGML_ABORT("fatal error");
1459 };
1460
1461 snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
1462 snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
1463
1464 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1465 if (!res.pipeline) {
1466 ggml_metal_cv_t cv = ggml_metal_cv_init();
1467
1468 ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1469 ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
1470 ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
1471
1472 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1473
1474 ggml_metal_cv_free(cv);
1475 }
1476
1477 return res;
1478}
1479
1480ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1481 assert(op->op == GGML_OP_L2_NORM);
1482
1483 char base[256];
1484 char name[256];
1485
1486 const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
1487
1488 const char * t0_str = ggml_type_name(op->src[0]->type);
1489 const char * t_str = ggml_type_name(op->type);
1490
1491 snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
1492 snprintf(name, 256, "%s", base);
1493
1494 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1495 if (!res.pipeline) {
1496 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1497 }
1498
1499 res.c4 = is_c4;
1500 res.smem = 32*sizeof(float);
1501
1502 return res;
1503}
1504
1505ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1506 assert(op->op == GGML_OP_GROUP_NORM);
1507
1508 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1509
1510 char base[256];
1511 char name[256];
1512
1513 snprintf(base, 256, "kernel_group_norm_f32");
1514 snprintf(name, 256, "%s", base);
1515
1516 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1517 if (!res.pipeline) {
1518 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1519 }
1520
1521 res.smem = 32*sizeof(float);
1522
1523 return res;
1524}
1525
1526ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1527 assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
1528
1529 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1530
1531 char base[256];
1532 char name[256];
1533
1534 const char * suffix = "";
1535 if (op->ne[0] % 4 == 0) {
1536 suffix = "_4";
1537 }
1538
1539 switch (op->op) {
1540 case GGML_OP_NORM:
1541 switch (n_fuse) {
1542 case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
1543 case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
1544 case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
1545 default: GGML_ABORT("fatal error");
1546 } break;
1547 case GGML_OP_RMS_NORM:
1548 switch (n_fuse) {
1549 case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
1550 case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
1551 case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
1552 default: GGML_ABORT("fatal error");
1553 } break;
1554 default: GGML_ABORT("fatal error");
1555 }
1556
1557 snprintf(name, 256, "%s", base);
1558
1559 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1560 if (!res.pipeline) {
1561 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1562 }
1563
1564 res.smem = 32*sizeof(float);
1565
1566 return res;
1567}
1568
1569ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1570 assert(op->op == GGML_OP_ROPE);
1571
1572 char base[256];
1573 char name[256];
1574
1575 const int mode = ((const int32_t *) op->op_params)[2];
1576
1577 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
1578 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1579 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
1580 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
1581
1582 if (is_neox) {
1583 snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1584 } else if ((is_mrope || is_imrope) && !is_vision) {
1585 GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1586 snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
1587 } else if (is_vision) {
1588 GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1589 snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type));
1590 } else {
1591 snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
1592 }
1593
1594 snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1595
1596 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1597 if (!res.pipeline) {
1598 ggml_metal_cv_t cv = ggml_metal_cv_init();
1599
1600 ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1601
1602 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1603
1604 ggml_metal_cv_free(cv);
1605 }
1606
1607 return res;
1608}
1609
1610ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1611 assert(op->op == GGML_OP_IM2COL);
1612
1613 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1614 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1615 GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1616
1617 char base[256];
1618 char name[256];
1619
1620 snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1621 snprintf(name, 256, "%s", base);
1622
1623 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1624 if (!res.pipeline) {
1625 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1626 }
1627
1628 return res;
1629}
1630
1631ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1632 assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
1633
1634 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1635 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1636 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1637 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1638 GGML_ASSERT(op->type == GGML_TYPE_F32);
1639
1640 char base[256];
1641 char name[256];
1642
1643 snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1644 snprintf(name, 256, "%s", base);
1645
1646 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1647 if (!res.pipeline) {
1648 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1649 }
1650
1651 return res;
1652}
1653
1654ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1655 assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
1656
1657 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1658 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1659 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1660 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1661 GGML_ASSERT(op->type == GGML_TYPE_F32);
1662
1663 char base[256];
1664 char name[256];
1665
1666 snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1667 snprintf(name, 256, "%s", base);
1668
1669 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1670 if (!res.pipeline) {
1671 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1672 }
1673
1674 return res;
1675}
1676
1677ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1678 assert(op->op == GGML_OP_CONV_2D);
1679
1680 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1681 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1682 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1683 GGML_ASSERT(op->type == GGML_TYPE_F32);
1684
1685 char base[256];
1686 char name[256];
1687
1688 snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1689 snprintf(name, 256, "%s", base);
1690
1691 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1692 if (!res.pipeline) {
1693 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1694 }
1695
1696 return res;
1697}
1698
1699ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1700 assert(op->op == GGML_OP_UPSCALE);
1701
1702 char base[256];
1703 char name[256];
1704
1705 snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
1706 snprintf(name, 256, "%s", base);
1707
1708 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1709 if (!res.pipeline) {
1710 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1711 }
1712
1713 return res;
1714}
1715
1716ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1717 assert(op->op == GGML_OP_PAD);
1718
1719 char base[256];
1720 char name[256];
1721
1722 snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1723 snprintf(name, 256, "%s", base);
1724
1725 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1726 if (res.pipeline) {
1727 return res;
1728 }
1729
1730 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1731
1732 return res;
1733}
1734
1735ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1736 assert(op->op == GGML_OP_PAD_REFLECT_1D);
1737
1738 char base[256];
1739 char name[256];
1740
1741 snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
1742 snprintf(name, 256, "%s", base);
1743
1744 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1745 if (!res.pipeline) {
1746 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1747 }
1748
1749 return res;
1750}
1751
1752ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1753 assert(op->op == GGML_OP_ARANGE);
1754
1755 char base[256];
1756 char name[256];
1757
1758 snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
1759 snprintf(name, 256, "%s", base);
1760
1761 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1762 if (!res.pipeline) {
1763 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1764 }
1765
1766 return res;
1767}
1768
1769ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1770 assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
1771
1772 char base[256];
1773 char name[256];
1774
1775 snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
1776 snprintf(name, 256, "%s", base);
1777
1778 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1779 if (!res.pipeline) {
1780 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1781 }
1782
1783 return res;
1784}
1785
1786ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1787 assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1788
1789 char base[256];
1790 char name[256];
1791
1792 snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1793 snprintf(name, 256, "%s", base);
1794
1795 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1796 if (!res.pipeline) {
1797 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1798 }
1799
1800 return res;
1801}
1802
1803ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1804 assert(op->op == GGML_OP_OPT_STEP_SGD);
1805
1806 char base[256];
1807 char name[256];
1808
1809 snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1810 snprintf(name, 256, "%s", base);
1811
1812 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1813 if (!res.pipeline) {
1814 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1815 }
1816
1817 return res;
1818}
1819
1820ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
1821 GGML_ASSERT(op->type == GGML_TYPE_I64);
1822
1823 char base[256];
1824 char name[256];
1825
1826 snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
1827 snprintf(name, 256, "%s", base);
1828
1829 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1830 if (!res.pipeline) {
1831 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1832 }
1833
1834 return res;
1835}
1836
1837ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
1838 assert(op->op == GGML_OP_COUNT_EQUAL);
1839
1840 GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1841
1842 GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1843 GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
1844 GGML_ASSERT(op->type == GGML_TYPE_I64);
1845
1846 // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
1847 GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
1848
1849 char base[256];
1850 char name[256];
1851
1852 int nsg = 1;
1853 while (32*nsg < ne00 && nsg < 32) {
1854 nsg *= 2;
1855 }
1856
1857 snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
1858 snprintf(name, 256, "%s_nsg=%d", base, nsg);
1859
1860 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1861 if (!res.pipeline) {
1862 ggml_metal_cv_t cv = ggml_metal_cv_init();
1863
1864 ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1865
1866 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1867
1868 ggml_metal_cv_free(cv);
1869 }
1870
1871 res.smem = 32 * sizeof(int32_t);
1872 res.nsg = nsg;
1873
1874 return res;
1875}