summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h')
-rw-r--r--llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h1051
1 files changed, 1051 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
new file mode 100644
index 0000000..952e1be
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -0,0 +1,1051 @@
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 2
+#define N_SG_Q8_0 4
+
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 2
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 2
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
+// function constants offsets
+#define FC_FLASH_ATTN_EXT_PAD 100
+#define FC_FLASH_ATTN_EXT_BLK 200
+#define FC_FLASH_ATTN_EXT 300
+#define FC_FLASH_ATTN_EXT_VEC 400
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
+#define FC_MUL_MV 600
+#define FC_MUL_MM 700
+#define FC_ROPE 800
+#define FC_SSM_CONV 900
+#define FC_SOLVE_TRI 1000
+#define FC_COUNT_EQUAL 1100
+#define FC_UNARY 1200
+#define FC_BIN 1300
+
+// op-specific constants
+#define OP_FLASH_ATTN_EXT_NQPSG 8
+#define OP_FLASH_ATTN_EXT_NCPSG 64
+
+#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
+#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
+
+#define OP_UNARY_NUM_SCALE 10
+#define OP_UNARY_NUM_FILL 11
+#define OP_UNARY_NUM_CLAMP 12
+#define OP_UNARY_NUM_SQR 13
+#define OP_UNARY_NUM_SQRT 14
+#define OP_UNARY_NUM_SIN 15
+#define OP_UNARY_NUM_COS 16
+#define OP_UNARY_NUM_LOG 17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH 100
+#define OP_UNARY_NUM_RELU 101
+#define OP_UNARY_NUM_SIGMOID 102
+#define OP_UNARY_NUM_GELU 103
+#define OP_UNARY_NUM_GELU_ERF 104
+#define OP_UNARY_NUM_GELU_QUICK 105
+#define OP_UNARY_NUM_SILU 106
+#define OP_UNARY_NUM_ELU 107
+#define OP_UNARY_NUM_NEG 108
+#define OP_UNARY_NUM_ABS 109
+#define OP_UNARY_NUM_SGN 110
+#define OP_UNARY_NUM_STEP 111
+#define OP_UNARY_NUM_HARDSWISH 112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP 114
+#define OP_UNARY_NUM_SOFTPLUS 115
+#define OP_UNARY_NUM_EXPM1 116
+
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+// however, be careful from int overflows when using those in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ int32_t dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float slope;
+ float scale;
+ float bias;
+ float val;
+ float min;
+ float max;
+} ggml_metal_kargs_unary;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ uint64_t offs;
+ uint64_t o1[8];
+} ggml_metal_kargs_bin;
+
+typedef struct {
+ int64_t ne0;
+ int64_t ne1;
+ size_t nb01;
+ size_t nb02;
+ size_t nb11;
+ size_t nb21;
+} ggml_metal_kargs_add_id;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+ int64_t nk0;
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ int64_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+ int64_t ne10;
+ int64_t ne11;
+ int64_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ uint64_t offs;
+ bool inplace;
+} ggml_metal_kargs_set;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ int32_t n_past;
+ int32_t n_dims;
+ int32_t n_ctx_orig;
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+ int32_t sect_0;
+ int32_t sect_1;
+ int32_t sect_2;
+ int32_t sect_3;
+ bool src2;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ int32_t ne31;
+ int32_t ne32;
+ int32_t ne33;
+ uint64_t nb31;
+ uint64_t nb32;
+ uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_pad;
+
+typedef struct {
+ int32_t ne01;
+ int32_t ne30;
+ int32_t ne31;
+ int32_t ne32;
+ int32_t ne33;
+ uint64_t nb31;
+ uint64_t nb32;
+ uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_blk;
+
+typedef struct {
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ int32_t ns10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ns20;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ int32_t ne31;
+ int32_t ne32;
+ int32_t ne33;
+ uint64_t nb31;
+ uint64_t nb32;
+ uint64_t nb33;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ int32_t n_head_log2;
+ float logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ int32_t ns10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ns20;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ int32_t ne31;
+ int32_t ne32;
+ int32_t ne33;
+ uint64_t nb31;
+ uint64_t nb32;
+ uint64_t nb33;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ int32_t n_head_log2;
+ float logit_softcap;
+} ggml_metal_kargs_flash_attn_ext_vec;
+
+typedef struct {
+ int32_t nrows;
+} ggml_metal_kargs_flash_attn_ext_vec_reduce;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne02;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int16_t r2;
+ int16_t r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t nr0;
+ int16_t r2;
+ int16_t r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int16_t r2;
+ int16_t r3;
+} ggml_metal_kargs_mul_mv_ext;
+
+typedef struct {
+ int32_t ne02;
+ int32_t ne10;
+ int32_t ne11; // n_expert_used (bcast)
+ uint64_t nb11;
+ uint64_t nb12;
+ int32_t ne21; // n_tokens
+ int32_t ne20; // n_expert_used
+ uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne02;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne20;
+ int32_t ne21;
+ int32_t ne0;
+ int32_t ne1;
+ int16_t r2;
+ int16_t r3;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+ int32_t nei0;
+ int32_t nei1;
+ uint64_t nbi1;
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ int32_t ne0;
+ int32_t ne1;
+ uint64_t nb1;
+ int32_t nr0;
+} ggml_metal_kargs_mul_mv_id;
+
+// NORM
+// RMS_NORM
+typedef struct {
+ int32_t ne00;
+ int32_t ne00_t;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float eps;
+ int32_t nef1[3];
+ int32_t nef2[3];
+ int32_t nef3[3];
+ uint64_t nbf1[3];
+ uint64_t nbf2[3];
+ uint64_t nbf3[3];
+} ggml_metal_kargs_norm;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float eps;
+} ggml_metal_kargs_l2_norm;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ int32_t ngrp;
+ float eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+ int32_t IC;
+ int32_t IL;
+ int32_t K;
+ int32_t s0;
+ uint64_t nb0;
+ uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+ int32_t IC;
+ int32_t IH;
+ int32_t IW;
+ int32_t KH;
+ int32_t KW;
+ int32_t OC;
+ int32_t s0;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+} ggml_metal_kargs_conv_transpose_2d;
+
+typedef struct {
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ int32_t IW;
+ int32_t IH;
+ int32_t KW;
+ int32_t KH;
+ int32_t IC;
+ int32_t OC;
+ int32_t OW;
+ int32_t OH;
+ int32_t N;
+ int32_t s0;
+ int32_t s1;
+ int32_t p0;
+ int32_t p1;
+ int32_t d0;
+ int32_t d1;
+} ggml_metal_kargs_conv_2d;
+
+typedef struct {
+ uint64_t ofs0;
+ uint64_t ofs1;
+ int32_t IW;
+ int32_t IH;
+ int32_t CHW;
+ int32_t s0;
+ int32_t s1;
+ int32_t p0;
+ int32_t p1;
+ int32_t d0;
+ int32_t d1;
+ int32_t N;
+ int32_t KH;
+ int32_t KW;
+ int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct{
+ int32_t ne00;
+ uint64_t nb01;
+ int32_t ne10;
+ uint64_t nb11;
+ int32_t ne0;
+ uint64_t nb1;
+ int32_t i00;
+ int32_t i10;
+ float alpha;
+ float limit;
+} ggml_metal_kargs_glu;
+
+typedef struct {
+ uint64_t np;
+} ggml_metal_kargs_sum;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ int64_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t net0;
+ int64_t net1;
+ int64_t net2;
+ int64_t net3;
+ uint64_t nbt0;
+ uint64_t nbt1;
+ uint64_t nbt2;
+ uint64_t nbt3;
+ bool outb;
+} ggml_metal_kargs_cumsum_blk;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t net0;
+ int64_t net1;
+ int64_t net2;
+ int64_t net3;
+ uint64_t nbt0;
+ uint64_t nbt1;
+ uint64_t nbt2;
+ uint64_t nbt3;
+} ggml_metal_kargs_cumsum_add;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ int32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ int64_t ne10;
+ int64_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+ int64_t d_state;
+ int64_t d_inner;
+ int64_t n_head;
+ int64_t n_group;
+ int64_t n_seq_tokens;
+ int64_t n_seqs;
+ uint64_t s_off;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t ns12;
+ uint64_t nb13;
+ uint64_t nb20;
+ uint64_t nb21;
+ uint64_t ns21;
+ uint64_t nb22;
+ int64_t ne30;
+ uint64_t nb31;
+ uint64_t nb41;
+ uint64_t nb42;
+ uint64_t ns42;
+ uint64_t nb43;
+ uint64_t nb51;
+ uint64_t nb52;
+ uint64_t ns52;
+ uint64_t nb53;
+ uint64_t nb0;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
+typedef struct {
+ int32_t ne00t;
+ int32_t ne00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+ int32_t nk0;
+ int32_t ne01;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne12;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_diag;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ int64_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float sf0;
+ float sf1;
+ float sf2;
+ float sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ int64_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int64_t ne0;
+ int64_t ne1;
+ int64_t ne2;
+ int64_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ int32_t p0;
+ int32_t p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+ uint64_t nb1;
+ int dim;
+ int max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_tri;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ int32_t top_k;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ int32_t top_k;
+ int32_t len;
+} ggml_metal_kargs_argsort_merge;
+
+typedef struct {
+ int64_t ne0;
+ float start;
+ float step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+ int64_t val;
+} ggml_metal_kargs_memset;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+} ggml_metal_kargs_count_equal;
+
+typedef struct {
+ int32_t k0;
+ int32_t k1;
+ int32_t s0;
+ int32_t s1;
+ int32_t p0;
+ int32_t p1;
+ int64_t IH;
+ int64_t IW;
+ int64_t OH;
+ int64_t OW;
+ int64_t np;
+} ggml_metal_kargs_pool_2d;
+
+typedef struct {
+ int32_t k0;
+ int32_t s0;
+ int32_t p0;
+ int64_t IW;
+ int64_t OW;
+ int64_t np;
+} ggml_metal_kargs_pool_1d;
+
+typedef struct {
+ int64_t ne00;
+ uint64_t nb01;
+} ggml_metal_kargs_argmax;
+
+typedef struct {
+ int64_t np;
+} ggml_metal_kargs_opt_step_adamw;
+
+typedef struct {
+ int64_t np;
+} ggml_metal_kargs_opt_step_sgd;
+
+#endif // GGML_METAL_IMPL