1#ifndef GGML_METAL_IMPL
2#define GGML_METAL_IMPL
3
// Mat-vec threadgroup tuning parameters, one pair per src0 type:
//
//   N_R0_<type> - number of src0 rows each simdgroup processes
//   N_SG_<type> - number of simdgroups in one threadgroup
//
// TODO: for optimal performance these should become a function of the device
//       and of the work size instead of fixed per-type constants.

// Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0 / MXFP4
#define N_R0_Q4_0    4
#define N_SG_Q4_0    2
#define N_R0_Q4_1    4
#define N_SG_Q4_1    2
#define N_R0_Q5_0    4
#define N_SG_Q5_0    2
#define N_R0_Q5_1    4
#define N_SG_Q5_1    2
#define N_R0_Q8_0    2
#define N_SG_Q8_0    4
#define N_R0_MXFP4   2
#define N_SG_MXFP4   2

// *_K types
#define N_R0_Q2_K    4
#define N_SG_Q2_K    2
#define N_R0_Q3_K    2
#define N_SG_Q3_K    2
#define N_R0_Q4_K    2
#define N_SG_Q4_K    2
#define N_R0_Q5_K    2
#define N_SG_Q5_K    2
#define N_R0_Q6_K    2
#define N_SG_Q6_K    2

// IQ* types
#define N_R0_IQ1_S   4
#define N_SG_IQ1_S   2
#define N_R0_IQ1_M   4
#define N_SG_IQ1_M   2
#define N_R0_IQ2_XXS 4
#define N_SG_IQ2_XXS 2
#define N_R0_IQ2_XS  4
#define N_SG_IQ2_XS  2
#define N_R0_IQ2_S   4
#define N_SG_IQ2_S   2
#define N_R0_IQ3_XXS 4
#define N_SG_IQ3_XXS 2
#define N_R0_IQ3_S   4
#define N_SG_IQ3_S   2
#define N_R0_IQ4_NL  2
#define N_SG_IQ4_NL  2
#define N_R0_IQ4_XS  2
#define N_SG_IQ4_XS  2
70
// Base offsets for Metal function-constant indices, spaced 100 apart so that
// each kernel family gets its own index range.
#define FC_FLASH_ATTN_EXT_PAD         100
#define FC_FLASH_ATTN_EXT_BLK         200
#define FC_FLASH_ATTN_EXT             300
#define FC_FLASH_ATTN_EXT_VEC         400
#define FC_FLASH_ATTN_EXT_VEC_REDUCE  500
#define FC_MUL_MV                     600
#define FC_MUL_MM                     700
#define FC_ROPE                       800
#define FC_SSM_CONV                   900
#define FC_SOLVE_TRI                 1000
#define FC_COUNT_EQUAL               1100
#define FC_UNARY                     1200
#define FC_BIN                       1300
85
// Op-specific constants.

// Flash-attention tiling parameters.
// NOTE(review): NQPSG / NCPSG presumably mean "queries per simdgroup" and
// "cache items per simdgroup" - confirm against the kernels.
#define OP_FLASH_ATTN_EXT_NQPSG      8
#define OP_FLASH_ATTN_EXT_NCPSG     64

#define OP_FLASH_ATTN_EXT_VEC_NQPSG  1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32

// Unary op selectors: a first group numbered from 10 ...
#define OP_UNARY_NUM_SCALE        10
#define OP_UNARY_NUM_FILL         11
#define OP_UNARY_NUM_CLAMP        12
#define OP_UNARY_NUM_SQR          13
#define OP_UNARY_NUM_SQRT         14
#define OP_UNARY_NUM_SIN          15
#define OP_UNARY_NUM_COS          16
#define OP_UNARY_NUM_LOG          17
#define OP_UNARY_NUM_LEAKY_RELU   18

// ... and a second group numbered from 100.
#define OP_UNARY_NUM_TANH        100
#define OP_UNARY_NUM_RELU        101
#define OP_UNARY_NUM_SIGMOID     102
#define OP_UNARY_NUM_GELU        103
#define OP_UNARY_NUM_GELU_ERF    104
#define OP_UNARY_NUM_GELU_QUICK  105
#define OP_UNARY_NUM_SILU        106
#define OP_UNARY_NUM_ELU         107
#define OP_UNARY_NUM_NEG         108
#define OP_UNARY_NUM_ABS         109
#define OP_UNARY_NUM_SGN         110
#define OP_UNARY_NUM_STEP        111
#define OP_UNARY_NUM_HARDSWISH   112
#define OP_UNARY_NUM_HARDSIGMOID 113
#define OP_UNARY_NUM_EXP         114
#define OP_UNARY_NUM_SOFTPLUS    115
#define OP_UNARY_NUM_EXPM1       116
120
121
122// kernel argument structs
123//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage;
//   however, be careful of int overflows when using them in the kernel implementation
126//
127// - strides (e.g. nb00) use uint64_t
128
// Arguments for the concat kernel(s).
// ne*: element counts (int32_t), nb*: strides (uint64_t); `dim` selects the
// concatenation axis (NOTE(review): inferred from the name - confirm).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;
156
// Arguments for the unary kernel(s): src/dst shapes and strides followed by
// scalar parameters consumed by the individual unary ops.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    float    slope, scale, bias, val, min, max;
} ggml_metal_kargs_unary;
181
// Arguments for the binary-op kernel(s).
// `offs` and the 8-entry `o1` table are extra offsets.
// NOTE(review): their exact use is defined by the kernels/host code, not visible here.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    uint64_t offs;
    uint64_t o1[8];
} ggml_metal_kargs_bin;
210
// Arguments for the add_id kernel(s).
//
// Strides are declared uint64_t, the fixed-width stride type used by every other
// kernel-argument struct in this header. The previous size_t declarations made
// the field width platform-dependent in a struct whose layout must match the
// kernel side. On 64-bit hosts the layout is unchanged (6 x 8 bytes).
typedef struct {
    int64_t  ne0;
    int64_t  ne1;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb11;
    uint64_t nb21;
} ggml_metal_kargs_add_id;
219
// Arguments for the repeat kernel(s): source and destination shapes/strides.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_repeat;
238
// Arguments for the copy kernel(s); every count here is 64-bit.
typedef struct {
    int64_t  nk0;
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_cpy;
258
// Arguments for the set kernel(s): src1 shape/strides, dst strides, a base
// offset and an in-place flag.
typedef struct {
    int64_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    uint64_t offs;
    bool     inplace;
} ggml_metal_kargs_set;
273
// Arguments for the rope kernel(s): shapes/strides, rope hyper-parameters,
// four section sizes and a flag telling whether a src2 input is present
// (NOTE(review): flag meaning inferred from its name - confirm).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  n_past, n_dims, n_ctx_orig;
    float    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    int32_t  sect_0, sect_1, sect_2, sect_3;
    bool     src2;
} ggml_metal_kargs_rope;
306
// Arguments for the flash-attention padding kernel(s).
typedef struct {
    int32_t  ne11, ne_12_2, ne_12_3; // assume K and V are same shape
    uint64_t nb11, nb12, nb13;
    uint64_t nb21, nb22, nb23;
    int32_t  ne31, ne32, ne33;
    uint64_t nb31, nb32, nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
324
// Arguments for the flash-attention block kernel(s).
typedef struct {
    int32_t  ne01;
    int32_t  ne30, ne31, ne32, ne33;
    uint64_t nb31, nb32, nb33;
} ggml_metal_kargs_flash_attn_ext_blk;
335
// Arguments for the flash-attention kernel(s): operand shapes/strides,
// destination dims and the softmax/ALiBi/softcap scalars.
typedef struct {
    int32_t  ne01, ne02, ne03;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne_12_2, ne_12_3; // assume K and V are same shape
    int32_t  ns10;
    uint64_t nb11, nb12, nb13;
    int32_t  ns20;
    uint64_t nb21, nb22, nb23;
    int32_t  ne31, ne32, ne33;
    uint64_t nb31, nb32, nb33;
    int32_t  ne1, ne2, ne3;
    float    scale, max_bias, m0, m1;
    int32_t  n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
370
// Arguments for the vector flash-attention kernel(s). The field list has the
// same order and types as ggml_metal_kargs_flash_attn_ext.
typedef struct {
    int32_t  ne01, ne02, ne03;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne_12_2, ne_12_3; // assume K and V are same shape
    int32_t  ns10;
    uint64_t nb11, nb12, nb13;
    int32_t  ns20;
    uint64_t nb21, nb22, nb23;
    int32_t  ne31, ne32, ne33;
    uint64_t nb31, nb32, nb33;
    int32_t  ne1, ne2, ne3;
    float    scale, max_bias, m0, m1;
    int32_t  n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext_vec;
405
// Arguments for the vector flash-attention reduction kernel(s).
typedef struct { int32_t nrows; } ggml_metal_kargs_flash_attn_ext_vec_reduce;
409
// Arguments for the mat-mat multiplication kernel(s); r2/r3 are 16-bit
// ratios (NOTE(review): presumably broadcast factors - confirm).
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm;
426
// Arguments for the mat-vec multiplication kernel(s).
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int32_t  nr0;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mv;
448
// Arguments for the extended mat-vec kernel(s). The fields match
// ggml_metal_kargs_mul_mv except that there is no nr0.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mv_ext;
469
// Arguments for the first (map) stage of the indexed mat-mat kernel(s).
typedef struct {
    int32_t  ne02, ne10;
    int32_t  ne11;       // n_expert_used (bcast)
    uint64_t nb11, nb12;
    int32_t  ne21;       // n_tokens
    int32_t  ne20;       // n_expert_used
    uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
480
// Arguments for the indexed mat-mat multiplication kernel(s).
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne20, ne21;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm_id;
499
// Arguments for the indexed mat-vec multiplication kernel(s): the index
// tensor (nei*/nbi*), then src0, src1 and dst descriptors.
typedef struct {
    int32_t  nei0, nei1;
    uint64_t nbi1;
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12;
    int32_t  ne0, ne1;
    uint64_t nb1;
    int32_t  nr0;
} ggml_metal_kargs_mul_mv_id;
522
// Arguments shared by the NORM and RMS_NORM kernel(s).
// The nef*/nbf* arrays each hold three entries.
// NOTE(review): they presumably describe auxiliary fused operands - confirm.
typedef struct {
    int32_t  ne00, ne00_t;
    uint64_t nb1, nb2, nb3;
    float    eps;
    int32_t  nef1[3], nef2[3], nef3[3];
    uint64_t nbf1[3], nbf2[3], nbf3[3];
} ggml_metal_kargs_norm;
539
// Arguments for the L2-norm kernel(s).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    float    eps;
} ggml_metal_kargs_l2_norm;
559
// Arguments for the group-norm kernel(s).
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  ngrp;
    float    eps;
} ggml_metal_kargs_group_norm;
570
// Arguments for the 1-D transposed-convolution kernel(s).
typedef struct {
    int32_t  IC, IL, K, s0;
    uint64_t nb0, nb1;
} ggml_metal_kargs_conv_transpose_1d;
579
// Arguments for the 2-D transposed-convolution kernel(s).
typedef struct {
    int32_t  IC, IH, IW, KH, KW, OC, s0;
    uint64_t nb0, nb1, nb2;
} ggml_metal_kargs_conv_transpose_2d;
592
// Arguments for the 2-D convolution kernel(s): all strides first, then the
// spatial sizes and stride/padding/dilation settings.
typedef struct {
    uint64_t nb00, nb01, nb02, nb03;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  IW, IH;
    int32_t  KW, KH;
    int32_t  IC, OC;
    int32_t  OW, OH;
    int32_t  N;
    int32_t  s0, s1;
    int32_t  p0, p1;
    int32_t  d0, d1;
} ggml_metal_kargs_conv_2d;
622
// Arguments for the im2col kernel(s).
typedef struct {
    uint64_t ofs0, ofs1;
    int32_t  IW, IH, CHW;
    int32_t  s0, s1;
    int32_t  p0, p1;
    int32_t  d0, d1;
    int32_t  N;
    int32_t  KH, KW;
    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
640
// Arguments for the GLU kernel(s).
// NOTE(review): the int32/uint64 interleaving leaves 4-byte padding holes after
// ne00, ne10 and ne0. Do not reorder here, because the layout must keep
// matching the kernel side.
typedef struct {
    int32_t  ne00;
    uint64_t nb01;
    int32_t  ne10;
    uint64_t nb11;
    int32_t  ne0;
    uint64_t nb1;
    int32_t  i00, i10;
    float    alpha, limit;
} ggml_metal_kargs_glu;
653
// Arguments for the sum kernel(s).
typedef struct { uint64_t np; } ggml_metal_kargs_sum;
657
// Arguments for the row-sum kernel(s); all counts are 64-bit.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_sum_rows;
676
// Arguments for the block stage of the cumulative-sum kernel(s): source
// shape/strides, a secondary tensor (net*/nbt*) and an output flag.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  net0, net1, net2, net3;
    uint64_t nbt0, nbt1, nbt2, nbt3;
    bool     outb;
} ggml_metal_kargs_cumsum_blk;
696
// Arguments for the add stage of the cumulative-sum kernel(s).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  net0, net1, net2, net3;
    uint64_t nbt0, nbt1, nbt2, nbt3;
} ggml_metal_kargs_cumsum_add;
715
// Arguments for the softmax kernel(s): src0/src1/dst descriptors plus the
// scale and ALiBi-related scalars.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12, ne13;
    uint64_t nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    float    scale, max_bias, m0, m1;
    int32_t  n_head_log2;
} ggml_metal_kargs_soft_max;
738
// Arguments for the SSM convolution kernel(s); every field is 64-bit.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int64_t  ne10, ne11;
    uint64_t nb10, nb11;
    int64_t  ne0,  ne1,  ne2;
    uint64_t nb0,  nb1,  nb2;
} ggml_metal_kargs_ssm_conv;
757
// Arguments for the SSM scan kernel(s): problem sizes, a state offset and
// per-operand strides (nb*) with extra ns* strides for some operands.
// Every field is 64-bit.
typedef struct {
    int64_t  d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs;
    uint64_t s_off;
    uint64_t nb00, nb01, nb02, nb03;
    uint64_t nb10, nb11, nb12, ns12, nb13;
    uint64_t nb20, nb21, ns21, nb22;
    int64_t  ne30;
    uint64_t nb31;
    uint64_t nb41, nb42, ns42, nb43;
    uint64_t nb51, nb52, ns52, nb53;
    uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
791
// Arguments for the triangular-solve kernel(s).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_solve_tri;
818
// Arguments for the get_rows kernel(s).
typedef struct {
    int32_t  ne00t, ne00;
    uint64_t nb01, nb02, nb03;
    int32_t  ne10;
    uint64_t nb10, nb11, nb12;
    uint64_t nb1,  nb2,  nb3;
} ggml_metal_kargs_get_rows;
833
// Arguments for the set_rows kernel(s).
typedef struct {
    int32_t  nk0, ne01;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12;
    uint64_t nb10, nb11, nb12;
    uint64_t nb1,  nb2,  nb3;
} ggml_metal_kargs_set_rows;
849
// Arguments for the diag kernel(s).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_diag;
868
// Arguments for the upscale kernel(s): 64-bit shapes/strides followed by four
// float factors sf0..sf3 (NOTE(review): presumably per-dimension scale
// factors - confirm).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    float    sf0,  sf1,  sf2,  sf3;
} ggml_metal_kargs_upscale;
891
// Arguments for the pad kernel(s); all counts are 64-bit.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_pad;
910
// Arguments for the 1-D reflect-padding kernel(s).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  p0,   p1;
} ggml_metal_kargs_pad_reflect_1d;
931
// Arguments for the timestep-embedding kernel(s).
//
// The integer fields are declared int32_t, the fixed-width type used for
// 32-bit scalars in every other kernel-argument struct here. Plain `int`
// would leave the width implementation-defined in a struct shared with the
// kernels. The layout is unchanged: 8 + 4 + 4 = 16 bytes.
typedef struct {
    uint64_t nb1;
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
937
// Arguments for the tri kernel(s).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_tri;
956
// Arguments for the argsort kernel(s). Only destination counts are passed for
// dst (no dst strides), plus a top_k limit.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    int32_t  top_k;
} ggml_metal_kargs_argsort;
972
// Arguments for the argsort merge kernel(s). Source counts are 64-bit here,
// unlike ggml_metal_kargs_argsort.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    int32_t  top_k, len;
} ggml_metal_kargs_argsort_merge;
989
// Arguments for the arange kernel(s).
typedef struct {
    int64_t ne0;
    float   start, step;
} ggml_metal_kargs_arange;
995
// Arguments for the memset kernel(s).
typedef struct { int64_t val; } ggml_metal_kargs_memset;
999
// Arguments for the count_equal kernel(s). Counts are given only for src0,
// while strides are given for both src0 and src1.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    uint64_t nb10, nb11, nb12, nb13;
} ggml_metal_kargs_count_equal;
1014
// Arguments for the 2-D pooling kernel(s).
typedef struct {
    int32_t k0, k1;
    int32_t s0, s1;
    int32_t p0, p1;
    int64_t IH, IW;
    int64_t OH, OW;
    int64_t np;
} ggml_metal_kargs_pool_2d;
1028
// Arguments for the 1-D pooling kernel(s).
typedef struct {
    int32_t k0, s0, p0;
    int64_t IW, OW;
    int64_t np;
} ggml_metal_kargs_pool_1d;
1037
// Arguments for the argmax kernel(s).
typedef struct {
    int64_t  ne00;
    uint64_t nb01;
} ggml_metal_kargs_argmax;
1042
// Arguments for the AdamW optimizer-step kernel(s).
typedef struct { int64_t np; } ggml_metal_kargs_opt_step_adamw;
1046
// Arguments for the SGD optimizer-step kernel(s).
typedef struct { int64_t np; } ggml_metal_kargs_opt_step_sgd;
1050
1051#endif // GGML_METAL_IMPL