1#ifndef GGML_METAL_IMPL
   2#define GGML_METAL_IMPL
   3
   4// kernel parameters for mat-vec threadgroups
   5//
   6// N_R0: number of src0 rows to process per simdgroup
   7// N_SG: number of simdgroups per threadgroup
   8//
   9// TODO: for optimal performance, become function of the device and work size
  10
  11#define N_R0_Q4_0 4
  12#define N_SG_Q4_0 2
  13
  14#define N_R0_Q4_1 4
  15#define N_SG_Q4_1 2
  16
  17#define N_R0_Q5_0 4
  18#define N_SG_Q5_0 2
  19
  20#define N_R0_Q5_1 4
  21#define N_SG_Q5_1 2
  22
  23#define N_R0_Q8_0 2
  24#define N_SG_Q8_0 4
  25
  26#define N_R0_MXFP4 2
  27#define N_SG_MXFP4 2
  28
  29#define N_R0_Q2_K 4
  30#define N_SG_Q2_K 2
  31
  32#define N_R0_Q3_K 2
  33#define N_SG_Q3_K 2
  34
  35#define N_R0_Q4_K 2
  36#define N_SG_Q4_K 2
  37
  38#define N_R0_Q5_K 2
  39#define N_SG_Q5_K 2
  40
  41#define N_R0_Q6_K 2
  42#define N_SG_Q6_K 2
  43
  44#define N_R0_IQ1_S 4
  45#define N_SG_IQ1_S 2
  46
  47#define N_R0_IQ1_M 4
  48#define N_SG_IQ1_M 2
  49
  50#define N_R0_IQ2_XXS 4
  51#define N_SG_IQ2_XXS 2
  52
  53#define N_R0_IQ2_XS 4
  54#define N_SG_IQ2_XS 2
  55
  56#define N_R0_IQ2_S 4
  57#define N_SG_IQ2_S 2
  58
  59#define N_R0_IQ3_XXS 4
  60#define N_SG_IQ3_XXS 2
  61
  62#define N_R0_IQ3_S 4
  63#define N_SG_IQ3_S 2
  64
  65#define N_R0_IQ4_NL 2
  66#define N_SG_IQ4_NL 2
  67
  68#define N_R0_IQ4_XS 2
  69#define N_SG_IQ4_XS 2
  70
  71// function constants offsets
  72#define FC_FLASH_ATTN_EXT_PAD          100
  73#define FC_FLASH_ATTN_EXT_BLK          200
  74#define FC_FLASH_ATTN_EXT              300
  75#define FC_FLASH_ATTN_EXT_VEC          400
  76#define FC_FLASH_ATTN_EXT_VEC_REDUCE   500
  77#define FC_MUL_MV                      600
  78#define FC_MUL_MM                      700
  79#define FC_ROPE                        800
  80#define FC_SSM_CONV                    900
  81#define FC_SOLVE_TRI                   1000
  82#define FC_COUNT_EQUAL                 1100
  83#define FC_UNARY                       1200
  84#define FC_BIN                         1300
  85
  86// op-specific constants
  87#define OP_FLASH_ATTN_EXT_NQPSG 8
  88#define OP_FLASH_ATTN_EXT_NCPSG 64
  89
  90#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
  91#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
  92
  93#define OP_UNARY_NUM_SCALE      10
  94#define OP_UNARY_NUM_FILL       11
  95#define OP_UNARY_NUM_CLAMP      12
  96#define OP_UNARY_NUM_SQR        13
  97#define OP_UNARY_NUM_SQRT       14
  98#define OP_UNARY_NUM_SIN        15
  99#define OP_UNARY_NUM_COS        16
 100#define OP_UNARY_NUM_LOG        17
 101#define OP_UNARY_NUM_LEAKY_RELU 18
 102
 103#define OP_UNARY_NUM_TANH        100
 104#define OP_UNARY_NUM_RELU        101
 105#define OP_UNARY_NUM_SIGMOID     102
 106#define OP_UNARY_NUM_GELU        103
 107#define OP_UNARY_NUM_GELU_ERF    104
 108#define OP_UNARY_NUM_GELU_QUICK  105
 109#define OP_UNARY_NUM_SILU        106
 110#define OP_UNARY_NUM_ELU         107
 111#define OP_UNARY_NUM_NEG         108
 112#define OP_UNARY_NUM_ABS         109
 113#define OP_UNARY_NUM_SGN         110
 114#define OP_UNARY_NUM_STEP        111
 115#define OP_UNARY_NUM_HARDSWISH   112
 116#define OP_UNARY_NUM_HARDSIGMOID 113
 117#define OP_UNARY_NUM_EXP         114
 118#define OP_UNARY_NUM_SOFTPLUS    115
 119#define OP_UNARY_NUM_EXPM1       116
 120
 121
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage;
//   however, be careful of int overflows when using them in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
 128
 129typedef struct {
 130    int32_t  ne00;
 131    int32_t  ne01;
 132    int32_t  ne02;
 133    int32_t  ne03;
 134    uint64_t nb00;
 135    uint64_t nb01;
 136    uint64_t nb02;
 137    uint64_t nb03;
 138    int32_t  ne10;
 139    int32_t  ne11;
 140    int32_t  ne12;
 141    int32_t  ne13;
 142    uint64_t nb10;
 143    uint64_t nb11;
 144    uint64_t nb12;
 145    uint64_t nb13;
 146    int32_t  ne0;
 147    int32_t  ne1;
 148    int32_t  ne2;
 149    int32_t  ne3;
 150    uint64_t nb0;
 151    uint64_t nb1;
 152    uint64_t nb2;
 153    uint64_t nb3;
 154    int32_t  dim;
 155} ggml_metal_kargs_concat;
 156
 157typedef struct {
 158    int32_t  ne00;
 159    int32_t  ne01;
 160    int32_t  ne02;
 161    int32_t  ne03;
 162    uint64_t nb00;
 163    uint64_t nb01;
 164    uint64_t nb02;
 165    uint64_t nb03;
 166    int32_t  ne0;
 167    int32_t  ne1;
 168    int32_t  ne2;
 169    int32_t  ne3;
 170    uint64_t nb0;
 171    uint64_t nb1;
 172    uint64_t nb2;
 173    uint64_t nb3;
 174    float    slope;
 175    float    scale;
 176    float    bias;
 177    float    val;
 178    float    min;
 179    float    max;
 180} ggml_metal_kargs_unary;
 181
 182typedef struct {
 183    int32_t  ne00;
 184    int32_t  ne01;
 185    int32_t  ne02;
 186    int32_t  ne03;
 187    uint64_t nb00;
 188    uint64_t nb01;
 189    uint64_t nb02;
 190    uint64_t nb03;
 191    int32_t  ne10;
 192    int32_t  ne11;
 193    int32_t  ne12;
 194    int32_t  ne13;
 195    uint64_t nb10;
 196    uint64_t nb11;
 197    uint64_t nb12;
 198    uint64_t nb13;
 199    int32_t  ne0;
 200    int32_t  ne1;
 201    int32_t  ne2;
 202    int32_t  ne3;
 203    uint64_t nb0;
 204    uint64_t nb1;
 205    uint64_t nb2;
 206    uint64_t nb3;
 207    uint64_t offs;
 208    uint64_t o1[8];
 209} ggml_metal_kargs_bin;
 210
 211typedef struct {
 212    int64_t ne0;
 213    int64_t ne1;
 214    size_t nb01;
 215    size_t nb02;
 216    size_t nb11;
 217    size_t nb21;
 218} ggml_metal_kargs_add_id;
 219
 220typedef struct {
 221    int32_t  ne00;
 222    int32_t  ne01;
 223    int32_t  ne02;
 224    int32_t  ne03;
 225    uint64_t nb00;
 226    uint64_t nb01;
 227    uint64_t nb02;
 228    uint64_t nb03;
 229    int32_t  ne0;
 230    int32_t  ne1;
 231    int32_t  ne2;
 232    int32_t  ne3;
 233    uint64_t nb0;
 234    uint64_t nb1;
 235    uint64_t nb2;
 236    uint64_t nb3;
 237} ggml_metal_kargs_repeat;
 238
 239typedef struct {
 240    int64_t  nk0;
 241    int64_t  ne00;
 242    int64_t  ne01;
 243    int64_t  ne02;
 244    int64_t  ne03;
 245    uint64_t nb00;
 246    uint64_t nb01;
 247    uint64_t nb02;
 248    uint64_t nb03;
 249    int64_t  ne0;
 250    int64_t  ne1;
 251    int64_t  ne2;
 252    int64_t  ne3;
 253    uint64_t nb0;
 254    uint64_t nb1;
 255    uint64_t nb2;
 256    uint64_t nb3;
 257} ggml_metal_kargs_cpy;
 258
 259typedef struct {
 260    int64_t  ne10;
 261    int64_t  ne11;
 262    int64_t  ne12;
 263    uint64_t nb10;
 264    uint64_t nb11;
 265    uint64_t nb12;
 266    uint64_t nb13;
 267    uint64_t nb1;
 268    uint64_t nb2;
 269    uint64_t nb3;
 270    uint64_t offs;
 271    bool     inplace;
 272} ggml_metal_kargs_set;
 273
 274typedef struct {
 275    int32_t  ne00;
 276    int32_t  ne01;
 277    int32_t  ne02;
 278    int32_t  ne03;
 279    uint64_t nb00;
 280    uint64_t nb01;
 281    uint64_t nb02;
 282    uint64_t nb03;
 283    int32_t  ne0;
 284    int32_t  ne1;
 285    int32_t  ne2;
 286    int32_t  ne3;
 287    uint64_t nb0;
 288    uint64_t nb1;
 289    uint64_t nb2;
 290    uint64_t nb3;
 291    int32_t  n_past;
 292    int32_t  n_dims;
 293    int32_t  n_ctx_orig;
 294    float    freq_base;
 295    float    freq_scale;
 296    float    ext_factor;
 297    float    attn_factor;
 298    float    beta_fast;
 299    float    beta_slow;
 300    int32_t  sect_0;
 301    int32_t  sect_1;
 302    int32_t  sect_2;
 303    int32_t  sect_3;
 304    bool     src2;
 305} ggml_metal_kargs_rope;
 306
 307typedef struct {
 308    int32_t  ne11;
 309    int32_t  ne_12_2; // assume K and V are same shape
 310    int32_t  ne_12_3;
 311    uint64_t nb11;
 312    uint64_t nb12;
 313    uint64_t nb13;
 314    uint64_t nb21;
 315    uint64_t nb22;
 316    uint64_t nb23;
 317    int32_t  ne31;
 318    int32_t  ne32;
 319    int32_t  ne33;
 320    uint64_t nb31;
 321    uint64_t nb32;
 322    uint64_t nb33;
 323} ggml_metal_kargs_flash_attn_ext_pad;
 324
 325typedef struct {
 326    int32_t  ne01;
 327    int32_t  ne30;
 328    int32_t  ne31;
 329    int32_t  ne32;
 330    int32_t  ne33;
 331    uint64_t nb31;
 332    uint64_t nb32;
 333    uint64_t nb33;
 334} ggml_metal_kargs_flash_attn_ext_blk;
 335
 336typedef struct {
 337    int32_t  ne01;
 338    int32_t  ne02;
 339    int32_t  ne03;
 340    uint64_t nb01;
 341    uint64_t nb02;
 342    uint64_t nb03;
 343    int32_t  ne11;
 344    int32_t  ne_12_2; // assume K and V are same shape
 345    int32_t  ne_12_3;
 346    int32_t  ns10;
 347    uint64_t nb11;
 348    uint64_t nb12;
 349    uint64_t nb13;
 350    int32_t  ns20;
 351    uint64_t nb21;
 352    uint64_t nb22;
 353    uint64_t nb23;
 354    int32_t  ne31;
 355    int32_t  ne32;
 356    int32_t  ne33;
 357    uint64_t nb31;
 358    uint64_t nb32;
 359    uint64_t nb33;
 360    int32_t  ne1;
 361    int32_t  ne2;
 362    int32_t  ne3;
 363    float    scale;
 364    float    max_bias;
 365    float    m0;
 366    float    m1;
 367    int32_t  n_head_log2;
 368    float    logit_softcap;
 369} ggml_metal_kargs_flash_attn_ext;
 370
 371typedef struct {
 372    int32_t  ne01;
 373    int32_t  ne02;
 374    int32_t  ne03;
 375    uint64_t nb01;
 376    uint64_t nb02;
 377    uint64_t nb03;
 378    int32_t  ne11;
 379    int32_t  ne_12_2; // assume K and V are same shape
 380    int32_t  ne_12_3;
 381    int32_t  ns10;
 382    uint64_t nb11;
 383    uint64_t nb12;
 384    uint64_t nb13;
 385    int32_t  ns20;
 386    uint64_t nb21;
 387    uint64_t nb22;
 388    uint64_t nb23;
 389    int32_t  ne31;
 390    int32_t  ne32;
 391    int32_t  ne33;
 392    uint64_t nb31;
 393    uint64_t nb32;
 394    uint64_t nb33;
 395    int32_t  ne1;
 396    int32_t  ne2;
 397    int32_t  ne3;
 398    float    scale;
 399    float    max_bias;
 400    float    m0;
 401    float    m1;
 402    int32_t  n_head_log2;
 403    float    logit_softcap;
 404} ggml_metal_kargs_flash_attn_ext_vec;
 405
 406typedef struct {
 407    int32_t  nrows;
 408} ggml_metal_kargs_flash_attn_ext_vec_reduce;
 409
 410typedef struct {
 411    int32_t  ne00;
 412    int32_t  ne02;
 413    uint64_t nb01;
 414    uint64_t nb02;
 415    uint64_t nb03;
 416    int32_t  ne12;
 417    uint64_t nb10;
 418    uint64_t nb11;
 419    uint64_t nb12;
 420    uint64_t nb13;
 421    int32_t  ne0;
 422    int32_t  ne1;
 423    int16_t  r2;
 424    int16_t  r3;
 425} ggml_metal_kargs_mul_mm;
 426
 427typedef struct {
 428    int32_t  ne00;
 429    int32_t  ne01;
 430    int32_t  ne02;
 431    uint64_t nb00;
 432    uint64_t nb01;
 433    uint64_t nb02;
 434    uint64_t nb03;
 435    int32_t  ne10;
 436    int32_t  ne11;
 437    int32_t  ne12;
 438    uint64_t nb10;
 439    uint64_t nb11;
 440    uint64_t nb12;
 441    uint64_t nb13;
 442    int32_t  ne0;
 443    int32_t  ne1;
 444    int32_t  nr0;
 445    int16_t  r2;
 446    int16_t  r3;
 447} ggml_metal_kargs_mul_mv;
 448
 449typedef struct {
 450    int32_t  ne00;
 451    int32_t  ne01;
 452    int32_t  ne02;
 453    uint64_t nb00;
 454    uint64_t nb01;
 455    uint64_t nb02;
 456    uint64_t nb03;
 457    int32_t  ne10;
 458    int32_t  ne11;
 459    int32_t  ne12;
 460    uint64_t nb10;
 461    uint64_t nb11;
 462    uint64_t nb12;
 463    uint64_t nb13;
 464    int32_t  ne0;
 465    int32_t  ne1;
 466    int16_t  r2;
 467    int16_t  r3;
 468} ggml_metal_kargs_mul_mv_ext;
 469
 470typedef struct {
 471    int32_t  ne02;
 472    int32_t  ne10;
 473    int32_t  ne11;  // n_expert_used (bcast)
 474    uint64_t nb11;
 475    uint64_t nb12;
 476    int32_t  ne21; // n_tokens
 477    int32_t  ne20;  // n_expert_used
 478    uint64_t nb21;
 479} ggml_metal_kargs_mul_mm_id_map0;
 480
 481typedef struct {
 482    int32_t  ne00;
 483    int32_t  ne02;
 484    uint64_t nb01;
 485    uint64_t nb02;
 486    uint64_t nb03;
 487    int32_t  ne11;
 488    uint64_t nb10;
 489    uint64_t nb11;
 490    uint64_t nb12;
 491    uint64_t nb13;
 492    int32_t  ne20;
 493    int32_t  ne21;
 494    int32_t  ne0;
 495    int32_t  ne1;
 496    int16_t  r2;
 497    int16_t  r3;
 498} ggml_metal_kargs_mul_mm_id;
 499
 500typedef struct {
 501    int32_t  nei0;
 502    int32_t  nei1;
 503    uint64_t nbi1;
 504    int32_t  ne00;
 505    int32_t  ne01;
 506    int32_t  ne02;
 507    uint64_t nb00;
 508    uint64_t nb01;
 509    uint64_t nb02;
 510    int32_t  ne10;
 511    int32_t  ne11;
 512    int32_t  ne12;
 513    int32_t  ne13;
 514    uint64_t nb10;
 515    uint64_t nb11;
 516    uint64_t nb12;
 517    int32_t  ne0;
 518    int32_t  ne1;
 519    uint64_t nb1;
 520    int32_t  nr0;
 521} ggml_metal_kargs_mul_mv_id;
 522
 523// NORM
 524// RMS_NORM
 525typedef struct {
 526    int32_t  ne00;
 527    int32_t  ne00_t;
 528    uint64_t nb1;
 529    uint64_t nb2;
 530    uint64_t nb3;
 531    float    eps;
 532    int32_t  nef1[3];
 533    int32_t  nef2[3];
 534    int32_t  nef3[3];
 535    uint64_t nbf1[3];
 536    uint64_t nbf2[3];
 537    uint64_t nbf3[3];
 538} ggml_metal_kargs_norm;
 539
 540typedef struct {
 541    int32_t  ne00;
 542    int32_t  ne01;
 543    int32_t  ne02;
 544    int32_t  ne03;
 545    uint64_t nb00;
 546    uint64_t nb01;
 547    uint64_t nb02;
 548    uint64_t nb03;
 549    int32_t  ne0;
 550    int32_t  ne1;
 551    int32_t  ne2;
 552    int32_t  ne3;
 553    uint64_t nb0;
 554    uint64_t nb1;
 555    uint64_t nb2;
 556    uint64_t nb3;
 557    float    eps;
 558} ggml_metal_kargs_l2_norm;
 559
 560typedef struct {
 561    int64_t  ne00;
 562    int64_t  ne01;
 563    int64_t  ne02;
 564    uint64_t nb00;
 565    uint64_t nb01;
 566    uint64_t nb02;
 567    int32_t  ngrp;
 568    float    eps;
 569} ggml_metal_kargs_group_norm;
 570
 571typedef struct {
 572    int32_t  IC;
 573    int32_t  IL;
 574    int32_t  K;
 575    int32_t  s0;
 576    uint64_t nb0;
 577    uint64_t nb1;
 578} ggml_metal_kargs_conv_transpose_1d;
 579
 580typedef struct {
 581    int32_t  IC;
 582    int32_t  IH;
 583    int32_t  IW;
 584    int32_t  KH;
 585    int32_t  KW;
 586    int32_t  OC;
 587    int32_t  s0;
 588    uint64_t nb0;
 589    uint64_t nb1;
 590    uint64_t nb2;
 591} ggml_metal_kargs_conv_transpose_2d;
 592
 593typedef struct {
 594    uint64_t nb00;
 595    uint64_t nb01;
 596    uint64_t nb02;
 597    uint64_t nb03;
 598    uint64_t nb10;
 599    uint64_t nb11;
 600    uint64_t nb12;
 601    uint64_t nb13;
 602    uint64_t nb0;
 603    uint64_t nb1;
 604    uint64_t nb2;
 605    uint64_t nb3;
 606    int32_t  IW;
 607    int32_t  IH;
 608    int32_t  KW;
 609    int32_t  KH;
 610    int32_t  IC;
 611    int32_t  OC;
 612    int32_t  OW;
 613    int32_t  OH;
 614    int32_t  N;
 615    int32_t  s0;
 616    int32_t  s1;
 617    int32_t  p0;
 618    int32_t  p1;
 619    int32_t  d0;
 620    int32_t  d1;
 621} ggml_metal_kargs_conv_2d;
 622
 623typedef struct {
 624    uint64_t  ofs0;
 625    uint64_t  ofs1;
 626    int32_t  IW;
 627    int32_t  IH;
 628    int32_t  CHW;
 629    int32_t  s0;
 630    int32_t  s1;
 631    int32_t  p0;
 632    int32_t  p1;
 633    int32_t  d0;
 634    int32_t  d1;
 635    int32_t  N;
 636    int32_t  KH;
 637    int32_t  KW;
 638    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
 639} ggml_metal_kargs_im2col;
 640
 641typedef struct{
 642    int32_t  ne00;
 643    uint64_t nb01;
 644    int32_t  ne10;
 645    uint64_t nb11;
 646    int32_t  ne0;
 647    uint64_t nb1;
 648    int32_t  i00;
 649    int32_t  i10;
 650    float    alpha;
 651    float    limit;
 652} ggml_metal_kargs_glu;
 653
 654typedef struct {
 655    uint64_t np;
 656} ggml_metal_kargs_sum;
 657
 658typedef struct {
 659    int64_t  ne00;
 660    int64_t  ne01;
 661    int64_t  ne02;
 662    int64_t  ne03;
 663    uint64_t nb00;
 664    uint64_t nb01;
 665    uint64_t nb02;
 666    uint64_t nb03;
 667    int64_t  ne0;
 668    int64_t  ne1;
 669    int64_t  ne2;
 670    int64_t  ne3;
 671    uint64_t nb0;
 672    uint64_t nb1;
 673    uint64_t nb2;
 674    uint64_t nb3;
 675} ggml_metal_kargs_sum_rows;
 676
 677typedef struct {
 678    int64_t  ne00;
 679    int64_t  ne01;
 680    int64_t  ne02;
 681    int64_t  ne03;
 682    uint64_t nb00;
 683    uint64_t nb01;
 684    uint64_t nb02;
 685    uint64_t nb03;
 686    int64_t  net0;
 687    int64_t  net1;
 688    int64_t  net2;
 689    int64_t  net3;
 690    uint64_t nbt0;
 691    uint64_t nbt1;
 692    uint64_t nbt2;
 693    uint64_t nbt3;
 694    bool     outb;
 695} ggml_metal_kargs_cumsum_blk;
 696
 697typedef struct {
 698    int64_t  ne00;
 699    int64_t  ne01;
 700    int64_t  ne02;
 701    int64_t  ne03;
 702    uint64_t nb00;
 703    uint64_t nb01;
 704    uint64_t nb02;
 705    uint64_t nb03;
 706    int64_t  net0;
 707    int64_t  net1;
 708    int64_t  net2;
 709    int64_t  net3;
 710    uint64_t nbt0;
 711    uint64_t nbt1;
 712    uint64_t nbt2;
 713    uint64_t nbt3;
 714} ggml_metal_kargs_cumsum_add;
 715
 716typedef struct {
 717    int32_t  ne00;
 718    int32_t  ne01;
 719    int32_t  ne02;
 720    uint64_t nb01;
 721    uint64_t nb02;
 722    uint64_t nb03;
 723    int32_t  ne11;
 724    int32_t  ne12;
 725    int32_t  ne13;
 726    uint64_t nb11;
 727    uint64_t nb12;
 728    uint64_t nb13;
 729    uint64_t nb1;
 730    uint64_t nb2;
 731    uint64_t nb3;
 732    float    scale;
 733    float    max_bias;
 734    float    m0;
 735    float    m1;
 736    int32_t  n_head_log2;
 737} ggml_metal_kargs_soft_max;
 738
 739typedef struct {
 740    int64_t  ne00;
 741    int64_t  ne01;
 742    int64_t  ne02;
 743    uint64_t nb00;
 744    uint64_t nb01;
 745    uint64_t nb02;
 746    int64_t  ne10;
 747    int64_t  ne11;
 748    uint64_t nb10;
 749    uint64_t nb11;
 750    int64_t  ne0;
 751    int64_t  ne1;
 752    int64_t  ne2;
 753    uint64_t nb0;
 754    uint64_t nb1;
 755    uint64_t nb2;
 756} ggml_metal_kargs_ssm_conv;
 757
 758typedef struct {
 759    int64_t  d_state;
 760    int64_t  d_inner;
 761    int64_t  n_head;
 762    int64_t  n_group;
 763    int64_t  n_seq_tokens;
 764    int64_t  n_seqs;
 765    uint64_t s_off;
 766    uint64_t nb00;
 767    uint64_t nb01;
 768    uint64_t nb02;
 769    uint64_t nb03;
 770    uint64_t nb10;
 771    uint64_t nb11;
 772    uint64_t nb12;
 773    uint64_t ns12;
 774    uint64_t nb13;
 775    uint64_t nb20;
 776    uint64_t nb21;
 777    uint64_t ns21;
 778    uint64_t nb22;
 779    int64_t  ne30;
 780    uint64_t nb31;
 781    uint64_t nb41;
 782    uint64_t nb42;
 783    uint64_t ns42;
 784    uint64_t nb43;
 785    uint64_t nb51;
 786    uint64_t nb52;
 787    uint64_t ns52;
 788    uint64_t nb53;
 789    uint64_t nb0;
 790} ggml_metal_kargs_ssm_scan;
 791
 792typedef struct {
 793    int32_t  ne00;
 794    int32_t  ne01;
 795    int32_t  ne02;
 796    int32_t  ne03;
 797    uint64_t nb00;
 798    uint64_t nb01;
 799    uint64_t nb02;
 800    uint64_t nb03;
 801    int32_t  ne10;
 802    int32_t  ne11;
 803    int32_t  ne12;
 804    int32_t  ne13;
 805    uint64_t nb10;
 806    uint64_t nb11;
 807    uint64_t nb12;
 808    uint64_t nb13;
 809    int32_t  ne0;
 810    int32_t  ne1;
 811    int32_t  ne2;
 812    int32_t  ne3;
 813    uint64_t nb0;
 814    uint64_t nb1;
 815    uint64_t nb2;
 816    uint64_t nb3;
 817} ggml_metal_kargs_solve_tri;
 818
 819typedef struct {
 820    int32_t  ne00t;
 821    int32_t  ne00;
 822    uint64_t nb01;
 823    uint64_t nb02;
 824    uint64_t nb03;
 825    int32_t  ne10;
 826    uint64_t nb10;
 827    uint64_t nb11;
 828    uint64_t nb12;
 829    uint64_t nb1;
 830    uint64_t nb2;
 831    uint64_t nb3;
 832} ggml_metal_kargs_get_rows;
 833
 834typedef struct {
 835    int32_t  nk0;
 836    int32_t  ne01;
 837    uint64_t nb01;
 838    uint64_t nb02;
 839    uint64_t nb03;
 840    int32_t  ne11;
 841    int32_t  ne12;
 842    uint64_t nb10;
 843    uint64_t nb11;
 844    uint64_t nb12;
 845    uint64_t nb1;
 846    uint64_t nb2;
 847    uint64_t nb3;
 848} ggml_metal_kargs_set_rows;
 849
 850typedef struct {
 851    int32_t  ne00;
 852    int32_t  ne01;
 853    int32_t  ne02;
 854    int32_t  ne03;
 855    uint64_t nb00;
 856    uint64_t nb01;
 857    uint64_t nb02;
 858    uint64_t nb03;
 859    int32_t  ne0;
 860    int32_t  ne1;
 861    int32_t  ne2;
 862    int32_t  ne3;
 863    uint64_t nb0;
 864    uint64_t nb1;
 865    uint64_t nb2;
 866    uint64_t nb3;
 867} ggml_metal_kargs_diag;
 868
 869typedef struct {
 870    int64_t  ne00;
 871    int64_t  ne01;
 872    int64_t  ne02;
 873    int64_t  ne03;
 874    uint64_t nb00;
 875    uint64_t nb01;
 876    uint64_t nb02;
 877    uint64_t nb03;
 878    int64_t  ne0;
 879    int64_t  ne1;
 880    int64_t  ne2;
 881    int64_t  ne3;
 882    uint64_t nb0;
 883    uint64_t nb1;
 884    uint64_t nb2;
 885    uint64_t nb3;
 886    float    sf0;
 887    float    sf1;
 888    float    sf2;
 889    float    sf3;
 890} ggml_metal_kargs_upscale;
 891
 892typedef struct {
 893    int64_t  ne00;
 894    int64_t  ne01;
 895    int64_t  ne02;
 896    int64_t  ne03;
 897    uint64_t nb00;
 898    uint64_t nb01;
 899    uint64_t nb02;
 900    uint64_t nb03;
 901    int64_t  ne0;
 902    int64_t  ne1;
 903    int64_t  ne2;
 904    int64_t  ne3;
 905    uint64_t nb0;
 906    uint64_t nb1;
 907    uint64_t nb2;
 908    uint64_t nb3;
 909} ggml_metal_kargs_pad;
 910
 911typedef struct {
 912    int64_t  ne00;
 913    int64_t  ne01;
 914    int64_t  ne02;
 915    int64_t  ne03;
 916    uint64_t nb00;
 917    uint64_t nb01;
 918    uint64_t nb02;
 919    uint64_t nb03;
 920    int64_t  ne0;
 921    int64_t  ne1;
 922    int64_t  ne2;
 923    int64_t  ne3;
 924    uint64_t nb0;
 925    uint64_t nb1;
 926    uint64_t nb2;
 927    uint64_t nb3;
 928    int32_t  p0;
 929    int32_t  p1;
 930} ggml_metal_kargs_pad_reflect_1d;
 931
 932typedef struct {
 933    uint64_t nb1;
 934    int      dim;
 935    int      max_period;
 936} ggml_metal_kargs_timestep_embedding;
 937
 938typedef struct {
 939    int32_t  ne00;
 940    int32_t  ne01;
 941    int32_t  ne02;
 942    int32_t  ne03;
 943    uint64_t nb00;
 944    uint64_t nb01;
 945    uint64_t nb02;
 946    uint64_t nb03;
 947    int32_t  ne0;
 948    int32_t  ne1;
 949    int32_t  ne2;
 950    int32_t  ne3;
 951    uint64_t nb0;
 952    uint64_t nb1;
 953    uint64_t nb2;
 954    uint64_t nb3;
 955} ggml_metal_kargs_tri;
 956
 957typedef struct {
 958    int32_t  ne00;
 959    int32_t  ne01;
 960    int32_t  ne02;
 961    int32_t  ne03;
 962    uint64_t nb00;
 963    uint64_t nb01;
 964    uint64_t nb02;
 965    uint64_t nb03;
 966    int32_t  ne0;
 967    int32_t  ne1;
 968    int32_t  ne2;
 969    int32_t  ne3;
 970    int32_t  top_k;
 971} ggml_metal_kargs_argsort;
 972
 973typedef struct {
 974    int64_t  ne00;
 975    int64_t  ne01;
 976    int64_t  ne02;
 977    int64_t  ne03;
 978    uint64_t nb00;
 979    uint64_t nb01;
 980    uint64_t nb02;
 981    uint64_t nb03;
 982    int32_t  ne0;
 983    int32_t  ne1;
 984    int32_t  ne2;
 985    int32_t  ne3;
 986    int32_t  top_k;
 987    int32_t  len;
 988} ggml_metal_kargs_argsort_merge;
 989
 990typedef struct {
 991    int64_t  ne0;
 992    float    start;
 993    float    step;
 994} ggml_metal_kargs_arange;
 995
 996typedef struct {
 997    int64_t val;
 998} ggml_metal_kargs_memset;
 999
1000typedef struct {
1001    int32_t  ne00;
1002    int32_t  ne01;
1003    int32_t  ne02;
1004    int32_t  ne03;
1005    uint64_t nb00;
1006    uint64_t nb01;
1007    uint64_t nb02;
1008    uint64_t nb03;
1009    uint64_t nb10;
1010    uint64_t nb11;
1011    uint64_t nb12;
1012    uint64_t nb13;
1013} ggml_metal_kargs_count_equal;
1014
1015typedef struct {
1016    int32_t  k0;
1017    int32_t  k1;
1018    int32_t  s0;
1019    int32_t  s1;
1020    int32_t  p0;
1021    int32_t  p1;
1022    int64_t  IH;
1023    int64_t  IW;
1024    int64_t  OH;
1025    int64_t  OW;
1026    int64_t  np;
1027} ggml_metal_kargs_pool_2d;
1028
1029typedef struct {
1030    int32_t  k0;
1031    int32_t  s0;
1032    int32_t  p0;
1033    int64_t  IW;
1034    int64_t  OW;
1035    int64_t  np;
1036} ggml_metal_kargs_pool_1d;
1037
1038typedef struct {
1039     int64_t ne00;
1040    uint64_t nb01;
1041} ggml_metal_kargs_argmax;
1042
1043typedef struct {
1044    int64_t  np;
1045} ggml_metal_kargs_opt_step_adamw;
1046
1047typedef struct {
1048    int64_t  np;
1049} ggml_metal_kargs_opt_step_sgd;
1050
1051#endif // GGML_METAL_IMPL