1#define GGML_COMMON_DECL_METAL
   2#define GGML_COMMON_IMPL_METAL
   3#if defined(GGML_METAL_EMBED_LIBRARY)
   4__embed_ggml-common.h__
   5#else
   6#include "ggml-common.h"
   7#endif
   8#include "ggml-metal-impl.h"
   9
  10#include <metal_stdlib>
  11
  12#ifdef GGML_METAL_HAS_TENSOR
  13#include <metal_tensor>
  14
  15#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
  16#endif
  17
  18using namespace metal;
  19
// Generic helpers. NOTE: these are plain macros, so arguments are evaluated more
// than once (e.g. MAX evaluates the selected argument twice) - do not pass
// expressions with side effects.
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
// Brace-block swap; note `if (c) SWAP(a, b); else ...` will not parse because of the `;`.
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }

// Round x up to the next multiple of n. Only correct when n is a power of two,
// because it relies on the bit mask ~(n - 1).
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))

// A `for` loop preceded by a request to clang to fully unroll it.
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)

#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
  29
  30// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
  31//
  32// cmd:
  33//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal/ggml-metal.metal
  34//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
  35//
// Drop bf16 support when compiling for a Metal language version below 3.1
// (__METAL_VERSION__ 310), even if the host requested it - presumably because
// the bfloat type is not available there.
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif

#if defined(GGML_METAL_HAS_BF16)
// bfloat matrix aliases, analogous to the built-in float4x4 / half4x4
// (bfloat4x4 is the source tile type of dequantize_bf16 below).
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
  44
// IQ4_NL codebook: maps a 4-bit index to a non-uniformly spaced level in the
// int8 range; a weight is dequantized as d * kvalues_iq4nl_f[idx]. The entries
// are sorted ascending, which best_index_int8() relies on for its binary search.
constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};

// MXFP4 codebook: maps a 4-bit code to its value before scaling by the block's
// E8M0 exponent (value = e8m0_to_fp32(e) * kvalues_mxfp4_f[code]). Codes 8..15
// are the negations of codes 0..7, i.e. bit 3 acts as the sign bit.
constexpr constant static float kvalues_mxfp4_f[16] = {
    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
  52
  53static inline int best_index_int8(int n, constant float * val, float x) {
  54    if (x <= val[0]) return 0;
  55    if (x >= val[n-1]) return n-1;
  56    int ml = 0, mu = n-1;
  57    while (mu-ml > 1) {
  58        int mav = (ml+mu)/2;
  59        if (x < val[mav]) mu = mav; else ml = mav;
  60    }
  61    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
  62}
  63
  64static inline float e8m0_to_fp32(uint8_t x) {
  65    uint32_t bits;
  66
  67    if (x == 0) {
  68        bits = 0x00400000;
  69    } else {
  70        bits = (uint32_t) x << 23;
  71    }
  72
  73    return as_type<float>(bits);
  74}
  75
  76static inline float dot(float x, float y) {
  77    return x*y;
  78}
  79
  80// NOTE: this is not dequantizing - we are simply fitting the template
  81template <typename type4x4>
  82void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
  83    reg = (type4x4)(*src);
  84}
  85
  86template <typename type4>
  87void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
  88    reg = (type4)(*src);
  89}
  90
  91template <typename type4x4>
  92void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
  93    reg = (type4x4)(*src);
  94}
  95
  96template <typename type4>
  97void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
  98    reg = (type4)(*(src));
  99}
 100
 101#if defined(GGML_METAL_HAS_BF16)
 102template <typename type4x4>
 103void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
 104    reg = (type4x4)(*src);
 105}
 106
 107template <typename type4>
 108void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
 109    reg = (type4)(*(src));
 110}
 111#endif
 112
 113template <typename type4x4>
 114void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
 115    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
 116    const float d1 = il ? (xb->d / 16.h) : xb->d;
 117    const float d2 = d1 / 256.f;
 118    const float md = -8.h * xb->d;
 119    const ushort mask0 = il ? 0x00F0 : 0x000F;
 120    const ushort mask1 = mask0 << 8;
 121
 122    float4x4 reg_f;
 123
 124    for (int i = 0; i < 8; i++) {
 125        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
 126        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
 127    }
 128
 129    reg = (type4x4) reg_f;
 130}
 131
 132template <typename type4>
 133void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
 134    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
 135    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
 136    const float d2 = d1 / 256.f;
 137    const float md = -8.h * xb->d;
 138    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
 139    const ushort mask1 = mask0 << 8;
 140
 141    for (int i = 0; i < 2; i++) {
 142        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
 143        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
 144    }
 145}
 146
 147void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
 148#pragma METAL fp math_mode(safe)
 149    float amax = 0.0f; // absolute max
 150    float max  = 0.0f;
 151
 152    for (int j = 0; j < QK4_0; j++) {
 153        const float v = src[j];
 154        if (amax < fabs(v)) {
 155            amax = fabs(v);
 156            max  = v;
 157        }
 158    }
 159
 160    const float d = max / -8;
 161    const float id = d ? 1.0f/d : 0.0f;
 162
 163    dst.d = d;
 164
 165    for (int j = 0; j < QK4_0/2; ++j) {
 166        const float x0 = src[0       + j]*id;
 167        const float x1 = src[QK4_0/2 + j]*id;
 168
 169        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
 170        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
 171
 172        dst.qs[j]  = xi0;
 173        dst.qs[j] |= xi1 << 4;
 174    }
 175}
 176
 177void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
 178#pragma METAL fp math_mode(safe)
 179    float min = FLT_MAX;
 180    float max = -FLT_MAX;
 181
 182    for (int j = 0; j < QK4_1; j++) {
 183        const float v = src[j];
 184        if (min > v) min = v;
 185        if (max < v) max = v;
 186    }
 187
 188    const float d = (max - min) / ((1 << 4) - 1);
 189    const float id = d ? 1.0f/d : 0.0f;
 190
 191    dst.d = d;
 192    dst.m = min;
 193
 194    for (int j = 0; j < QK4_1/2; ++j) {
 195        const float x0 = (src[0       + j] - min)*id;
 196        const float x1 = (src[QK4_1/2 + j] - min)*id;
 197
 198        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
 199        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
 200
 201        dst.qs[j]  = xi0;
 202        dst.qs[j] |= xi1 << 4;
 203    }
 204}
 205
 206void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
 207#pragma METAL fp math_mode(safe)
 208    float amax = 0.0f; // absolute max
 209    float max  = 0.0f;
 210
 211    for (int j = 0; j < QK5_0; j++) {
 212        const float v = src[j];
 213        if (amax < fabs(v)) {
 214            amax = fabs(v);
 215            max  = v;
 216        }
 217    }
 218
 219    const float d = max / -16;
 220    const float id = d ? 1.0f/d : 0.0f;
 221
 222    dst.d = d;
 223
 224    uint32_t qh = 0;
 225    for (int j = 0; j < QK5_0/2; ++j) {
 226        const float x0 = src[0       + j]*id;
 227        const float x1 = src[QK5_0/2 + j]*id;
 228
 229        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
 230        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
 231
 232        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
 233        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
 234        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
 235    }
 236
 237    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
 238
 239    for (int j = 0; j < 4; ++j) {
 240        dst.qh[j] = qh8[j];
 241    }
 242}
 243
 244void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
 245#pragma METAL fp math_mode(safe)
 246    float max = src[0];
 247    float min = src[0];
 248
 249    for (int j = 1; j < QK5_1; j++) {
 250        const float v = src[j];
 251        min = v < min ? v : min;
 252        max = v > max ? v : max;
 253    }
 254
 255    const float d = (max - min) / 31;
 256    const float id = d ? 1.0f/d : 0.0f;
 257
 258    dst.d = d;
 259    dst.m = min;
 260
 261    uint32_t qh = 0;
 262    for (int j = 0; j < QK5_1/2; ++j) {
 263        const float x0 = (src[0       + j] - min)*id;
 264        const float x1 = (src[QK5_1/2 + j] - min)*id;
 265
 266        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
 267        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
 268
 269        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
 270        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
 271        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
 272    }
 273
 274    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
 275
 276    for (int j = 0; j < 4; ++j) {
 277        dst.qh[j] = qh8[j];
 278    }
 279}
 280
 281void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
 282#pragma METAL fp math_mode(safe)
 283    float amax = 0.0f; // absolute max
 284
 285    for (int j = 0; j < QK8_0; j++) {
 286        const float v = src[j];
 287        amax = MAX(amax, fabs(v));
 288    }
 289
 290    const float d = amax / ((1 << 7) - 1);
 291    const float id = d ? 1.0f/d : 0.0f;
 292
 293    dst.d = d;
 294
 295    for (int j = 0; j < QK8_0; ++j) {
 296        const float x0 = src[j]*id;
 297
 298        dst.qs[j] = round(x0);
 299    }
 300}
 301
 302void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
 303#pragma METAL fp math_mode(safe)
 304    float amax = 0.0f; // absolute max
 305    float max  = 0.0f;
 306
 307    for (int j = 0; j < QK4_NL; j++) {
 308        const float v = src[j];
 309        if (amax < fabs(v)) {
 310            amax = fabs(v);
 311            max  = v;
 312        }
 313    }
 314
 315    const float d = max / kvalues_iq4nl_f[0];
 316    const float id = d ? 1.0f/d : 0.0f;
 317
 318    float sumqx = 0, sumq2 = 0;
 319    for (int j = 0; j < QK4_NL/2; ++j) {
 320        const float x0 = src[0        + j]*id;
 321        const float x1 = src[QK4_NL/2 + j]*id;
 322
 323        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
 324        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
 325
 326        dst.qs[j] = xi0 | (xi1 << 4);
 327
 328        const float v0 = kvalues_iq4nl_f[xi0];
 329        const float v1 = kvalues_iq4nl_f[xi1];
 330        const float w0 = src[0        + j]*src[0        + j];
 331        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
 332        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
 333        sumq2 += w0*v0*v0 + w1*v1*v1;
 334
 335    }
 336
 337    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
 338}
 339
 340template <typename type4x4>
 341void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
 342    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
 343    const float d1 = il ? (xb->d / 16.h) : xb->d;
 344    const float d2 = d1 / 256.f;
 345    const float  m = xb->m;
 346    const ushort mask0 = il ? 0x00F0 : 0x000F;
 347    const ushort mask1 = mask0 << 8;
 348
 349    float4x4 reg_f;
 350
 351    for (int i = 0; i < 8; i++) {
 352        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
 353        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
 354    }
 355
 356    reg = (type4x4) reg_f;
 357}
 358
 359template <typename type4>
 360void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
 361    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
 362    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
 363    const float d2 = d1 / 256.f;
 364    const float  m = xb->m;
 365    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
 366    const ushort mask1 = mask0 << 8;
 367
 368    for (int i = 0; i < 2; i++) {
 369        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
 370        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
 371    }
 372}
 373
// Dequantize half of a Q5_0 block (16 quants) into a 4x4 tile.
//
// Each weight is a 5-bit value x = (4-bit nibble from qs) | (1 bit from qh) << 4,
// dequantized as d * x - 16 * d. il == 0 uses the low nibbles and qh bits 0..15;
// a non-zero il uses the high nibbles and qh bits 16..31. Tile element k
// (row-major) comes from quant byte k.
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
    // quants start 3 uint16 (6 bytes) into the block - presumably after the
    // half d and the 4-byte qh field; confirm against block_q5_0
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;
    const ushort mask = il ? 0x00F0 : 0x000F;

    // all 32 fifth-bits as one word.
    // NOTE(review): qh presumably sits 2 bytes into the block, so this 32-bit
    // load may be unaligned - confirm it is acceptable on the target GPUs.
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    // right shift that brings the selected nibble down to bits 0..3
    const int x_mv = il ? 4 : 0;

    // gh_mv/gh_bk move qh bit (2*i [+16]) onto bit 4 (0x10): for il == 0 shift
    // right by 2*i then left by 4; for il != 0 shift right by 12 + 2*i.
    const int gh_mv = il ? 12 : 0;
    const int gh_bk = il ?  0 : 4;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        // extract the 5th bit for x0 and x1 (the two bytes of word i)
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
    }

    reg = (type4x4) reg_f;
}
 405
// Dequantize 4 quants of a Q5_0 block into a 4-element vector.
//
// il is presumably in 0..7: il/4 selects low nibbles + qh bits 0..15 (0) or
// high nibbles + qh bits 16..31 (non-zero); il%4 selects the pair of 16-bit
// words (4 quant bytes). Each weight is d * x - 16 * d for the 5-bit x.
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;
    const ushort mask = (il/4) ? 0x00F0 : 0x000F;

    // NOTE(review): possibly unaligned 32-bit load of qh - see dequantize_q5_0
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = (il/4) ? 4 : 0;

    // move the selected qh bit onto bit 4 (0x10)
    const int gh_mv = (il/4) ? 12 : 0;
    const int gh_bk = (il/4) ?  0 : 4;

    for (int ii = 0; ii < 2; ii++) {
        // absolute 16-bit word index within the quants
        int i = 2*(il%4) + ii;

        // extract the 5th bit for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[2*ii + 0] = d * x0 + md;
        reg[2*ii + 1] = d * x1 + md;
    }
}
 435
// Dequantize half of a Q5_1 block (16 quants) into a 4x4 tile.
//
// Same bit layout as dequantize_q5_0, but with an explicit minimum:
// each weight is d * x + m for the 5-bit x. il == 0 uses the low nibbles and
// qh bits 0..15; a non-zero il uses the high nibbles and qh bits 16..31.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
    // quants start 4 uint16 (8 bytes) into the block - presumably after the
    // half d, half m and the 4-byte qh field; confirm against block_q5_1
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;

    // all 32 fifth-bits as one word
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    // right shift that brings the selected nibble down to bits 0..3
    const int x_mv = il ? 4 : 0;

    // move qh bit (2*i [+16]) onto bit 4 (0x10)
    const int gh_mv = il ? 12 : 0;
    const int gh_bk = il ?  0 : 4;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        // extract the 5th bit for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
    }

    reg = (type4x4) reg_f;
}
 467
// Dequantize 4 quants of a Q5_1 block into a 4-element vector.
//
// il is presumably in 0..7: il/4 picks low/high nibbles (and qh bits 0..15 /
// 16..31), and il%4 picks the pair of 16-bit words. Each weight is d * x + m.
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = (il/4) ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = (il/4) ? 4 : 0;

    // move the selected qh bit onto bit 4 (0x10)
    const int gh_mv = (il/4) ? 12 : 0;
    const int gh_bk = (il/4) ?  0 : 4;

    for (int ii = 0; ii < 2; ii++) {
        // absolute 16-bit word index within the quants
        int i = 2*(il%4) + ii;

        // extract the 5th bit for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[2*ii + 0] = d * x0 + m;
        reg[2*ii + 1] = d * x1 + m;
    }
}
 497
 498template <typename type4x4>
 499void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
 500    device const int8_t * qs = ((device const int8_t *)xb->qs);
 501    const float d = xb->d;
 502
 503    float4x4 reg_f;
 504
 505    for (int i = 0; i < 16; i++) {
 506        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
 507    }
 508
 509    reg = (type4x4) reg_f;
 510}
 511
 512template <typename type4>
 513void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
 514    device const int8_t * qs = ((device const int8_t *)xb->qs);
 515    const float d = xb->d;
 516
 517    for (int i = 0; i < 4; i++) {
 518        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
 519    }
 520}
 521
 522template <typename type4x4>
 523void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
 524    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 525
 526    const float d = e8m0_to_fp32(xb->e);
 527    const uint8_t shr = il >= 1 ? 4 : 0;
 528
 529    for (int i = 0; i < 4; ++i) {
 530        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
 531        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
 532        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
 533        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
 534    }
 535}
 536
 537template <typename type4>
 538void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
 539    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 540
 541    const float d = e8m0_to_fp32(xb->e);
 542    const short il4 = il%4;
 543
 544    const uint8_t shr = il >= 4 ? 4 : 0;
 545
 546    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
 547    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
 548    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
 549    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
 550}
 551
// Dequantize 16 weights of a Q2_K super-block into a 4x4 tile.
//
// il (presumably 0..15) indexes scales[il] directly. The 16 quant bytes start
// at qs + 32*(il/8) + 16*(il&1), and (il/2)%4 picks which 2-bit field of each
// byte (bits 0-1, 2-3, 4-5 or 6-7) is used. The field is masked in place and
// compensated by coef (1, 1/4, 1/16, 1/64) instead of being shifted down.
// Each weight is d * (sc & 0xF) * q - dmin * (sc >> 4).
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const float min = xb->dmin;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    float dl, ml;
    // low nibble: scale multiplier, high nibble: min multiplier
    uint8_t sc = xb->scales[il];

    q = q + 32*(il/8) + 16*(il&1);
    // from here on, il selects the 2-bit field within each byte
    il = (il/2)%4;

    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}
 570
// Dequantize 16 weights of a Q3_K super-block into a 4x4 tile.
//
// - Low 2 bits: 2-bit field (il/2)&3 of the bytes at qs + 32*(il/8) + 16*(il&1),
//   masked in place and compensated by coef (as in dequantize_q2_K).
// - High bit: bit il/2 of the bytes at hmask + 16*(il&1). When that bit is
//   clear, 4*dl is subtracted.
// - Scale: a 6-bit value minus 32. Its low 4 bits are a nibble of scales[il%8]
//   (low nibble for il < 8, high for il >= 8). Its top 2 bits are field il/4
//   of scales[8 + il%4], shifted so they land just above that nibble.
// Each weight is dl * q - (high bit set ? 0 : 4*dl).
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    device const uint8_t * h = (device const uint8_t *)xb->hmask;
    device const int8_t * scales = (device const int8_t *)xb->scales;

    q = q + 32 * (il/8) + 16 * (il&1);
    h = h + 16 * (il&1);
    // bit of hmask holding the high bit for this slice
    uint8_t m = 1 << (il/2);
    // kmask1: 2-bit field il/4 of scale_1; kmask2: nibble of scale_2
    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
                                 ((il/4)>0 ? 12  : 3);
    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
    // the << 4 / << 2 places the 2 extra bits directly above the selected nibble
    // (bits 4-5 for il < 8, bits 8-9 for il >= 8)
    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
    // for il >= 8 the nibble is the high one, so divide by 16 to normalize
    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
    const float ml = 4.f * dl;

    // from here on, il selects the 2-bit field within each quant byte
    il = (il/2) & 3;
    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl *= coef;

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
    }
}
 599
 600static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
 601    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
 602                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
 603}
 604
 605template <typename type4x4>
 606void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
 607    device const uchar * q = xb->qs;
 608
 609    short is = (il/4) * 2;
 610    q = q + (il/4) * 32 + 16 * (il&1);
 611    il = il & 3;
 612    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
 613    const float d   = il < 2 ? xb->d : xb->d / 16.h;
 614    const float min = xb->dmin;
 615    const float dl = d * sc[0];
 616    const float ml = min * sc[1];
 617
 618    const ushort mask = il < 2 ? 0x0F : 0xF0;
 619    for (int i = 0; i < 16; ++i) {
 620        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
 621    }
 622}
 623
// Dequantize 16 weights of a Q5_K super-block into a 4x4 tile.
//
// The low 4 bits are read exactly as in dequantize_q4_K. The 5th bit is bit il/2
// of the bytes at qh + 16*(il&1), and adds 16 to a low-nibble value (or 256 to a
// high nibble still masked in place, which the d/16 pre-division then
// compensates). Each weight is d * sc[0] * q - dmin * sc[1].
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q  = xb->qs;
    device const uint8_t * qh = xb->qh;

    short is = (il/4) * 2;
    q  = q + 32 * (il/4) + 16 * (il&1);
    qh = qh + 16 * (il&1);
    // bit of qh holding the 5th bit for this slice
    uint8_t ul = 1 << (il/2);
    il = il & 3;
    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
    const float d = il < 2 ? xb->d : xb->d / 16.f;
    const float min = xb->dmin;
    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask  = il<2 ? 0x0F : 0xF0;
    const float qh_val = il<2 ? 16.f : 256.f;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
    }
}
 646
// Dequantize 16 weights of a Q6_K super-block into a 4x4 tile.
//
// Each 6-bit value q combines a nibble from ql with a 2-bit field from qh.
// The loop processes 4 values at a time as one 32-bit word (one per byte):
// - low: nibble selected by kmask2 and shifted down by shr_l
// - high: 2-bit field selected by kmask1, moved to bits 4-5 of each byte with
//   shl_h / shr_h (the masks keep every field inside its own byte)
// The bytes are then extracted by masking in place, and the scale for byte n
// is pre-divided by 256^n. Each weight is d * sc * q - 32 * d * sc.
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint16_t * ql = (device const uint16_t *)xb->ql;
    device const uint16_t * qh = (device const uint16_t *)xb->qh;
    device const int8_t * scales = (device const int8_t *)xb->scales;

    // offsets are in uint16 units (8 words = 16 bytes per slice)
    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
    qh = qh + 16*(il/8) + 8*(il&1);
    // (il%2) + 2*(il/2) == il for il >= 0, i.e. the per-slice scale scales[il]
    float sc = scales[(il%2) + 2 * ((il/2))];
    // from here on, il selects the nibble of ql (il > 1: high) and the 2-bit field of qh
    il = (il/2) & 3;

    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
    const float ml = d_all * sc * 32.f;
    const float dl0 = d_all * sc;
    const float dl1 = dl0 / 256.f;
    const float dl2 = dl0 / (256.f * 256.f);
    const float dl3 = dl0 / (256.f * 256.f * 256.f);
    const uint8_t shr_h = il>2 ? 2 : 0;
    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
    const uint8_t shr_l = il>1 ? 4 : 0;
    for (int i = 0; i < 4; ++i) {
        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
        // four 6-bit values, one per byte
        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
    }
}
 679
// Dequantize 16 weights of an IQ2_XXS super-block into a 4x4 tile.
//
// Each 32-weight sub-block is 4 uint16 words, read as two 32-bit values:
// - aux32_g: four byte-sized indices into iq2xxs_grid (8 weights each); this
//   call uses bytes 2*il and 2*il+1.
// - aux32_s: bits 28..31 hold the scale s, with dl = d * (0.5 + s) * 0.25.
//   Bits 0..27 hold four 7-bit indices into ksigns_iq2xs, one per 8 weights.
// Weight i of a group is dl * grid[i], negated when (signs & kmask_iq2xs[i]) != 0.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
    device const uint16_t * q2 = xb->qs + 4*ib32;
    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
    // view the grid indices as bytes (relies on the thread-local byte order)
    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}
 704
// Dequantize 16 weights of an IQ2_XS super-block into a 4x4 tile.
//
// Each 16-bit word of q2 describes 8 weights: bits 0..8 index iq2xs_grid and
// bits 9..15 index ksigns_iq2xs. The 4-bit scale s is nibble il of
// scales[ib32], with dl = d * (0.5 + s) * 0.25.
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint16_t * q2 = xb->qs + 4*ib32;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}
 725
// Dequantize 16 weights of an IQ3_XXS super-block into a 4x4 tile.
//
// q3 holds 8 grid-index bytes per 32-weight sub-block; each iq3xxs_grid entry
// supplies 4 weights. The per-sub-block scale/sign word sits after the first
// QK_K/4 quant bytes:
// - bits 28..31: scale s, with dl = d * (0.5 + s) * 0.5
// - bits 0..27: four 7-bit ksigns_iq2xs indices, each covering 8 weights
//   (two grid entries)
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * q3 = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }
    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }
}
 752
// Dequantize 16 weights of an IQ3_S super-block into a 4x4 tile.
//
// Each group of 4 weights comes from a 9-bit iq3s_grid index: a byte from qs,
// plus one bit of qh[ib32] (bits 4*il .. 4*il+3) as bit 8. Signs are stored
// explicitly: signs[0] and signs[1] hold one bit per weight, and a set bit
// negates the weight. Scale: s = nibble (ib32%2) of scales[ib32/2], with
// dl = d * (1 + 2*s).
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 8*ib32;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
    // qh bits 0..3 become bit 8 of the four grid indices used below
    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
    }
    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
    }
}
 777
// Dequantize 16 weights of an IQ2_S super-block into a 4x4 tile.
//
// Each group of 8 weights comes from a 10-bit iq2s_grid index: a byte from qs,
// plus 2 bits of qh[ib32] (after >> 4*il) as bits 8..9. Sign bytes are stored
// QK_K/8 bytes after the corresponding qs bytes, and a set bit negates the
// weight. Scale: s = nibble il of scales[ib32], with dl = d * (0.5 + s) * 0.25.
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
    // qh bits 0-1 / 2-3 become bits 8-9 of the two grid indices
    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
    for (int i = 0; i < 8; ++i) {
        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
    }
}
 796
// Dequantize 16 weights of an IQ1_S super-block into a 4x4 tile.
//
// qh[ib32] (16 bits) packs the following:
// - bits 12..14: scale s, with dl = d * (2*s + 1)
// - bit 15: sign of the IQ1S_DELTA offset, ml = dl * (-1 -/+ IQ1S_DELTA)
// - bits 0..11: 3 high index bits per grid entry (6 per half, selected by 6*il)
// Each 11-bit iq1s_grid_gpu index yields 4 bytes; the low nibbles are the first
// 4 values and the high nibbles the next 4. Weight = dl * value + ml.
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    const float d = xb->d;
    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
    device const uint16_t * qh = xb->qh;
    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
    const uint16_t h = qh[ib32] >> 6*il;
    // h bits 0-2 / 3-5 become bits 8-10 of the two grid indices
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
        reg[1][i] = dl * (grid1[i] >>  4) + ml;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
        reg[3][i] = dl * (grid2[i] >>  4) + ml;
    }
}
 817
// Dequantize 16 weights of an IQ1_M super-block into a 4x4 tile.
//
// - The super-block scale is reassembled from the top nibble of each of the
//   four 16-bit scale words and reinterpreted as a half
//   (iq1m_scale_t presumably unions u16/f16 - confirm).
// - The 3-bit sub-scale s is bits 6*(ib32%2) + 3*il of sc[ib32/2], with
//   dl = d * (2*s + 1).
// - qh byte (2*ib32 + il): bits 0-2 / 4-6 are the high index bits of the two
//   grid entries. Bits 3 / 7 choose the sign of the IQ1M_DELTA offset for each.
// Each grid entry yields 8 values (low nibbles, then high nibbles).
// Weight = dl * value + ml.
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    device const uint16_t * sc = (device const uint16_t *)xb->scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = scale.f16;

    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * qh = xb->qh + 2*ib32 + il;

    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
    }
}
 844
 845template <typename type4x4>
 846void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
 847    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
 848    const float d = xb->d;
 849    uint32_t aux32;
 850    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
 851    for (int i = 0; i < 4; ++i) {
 852        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
 853        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
 854        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
 855        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
 856        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
 857    }
 858}
 859
 860template <typename type4>
 861void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
 862    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
 863    const float d = xb->d;
 864    uint32_t aux32;
 865    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
 866    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
 867    reg[0] = d * kvalues_iq4nl_f[q8[0]];
 868    reg[1] = d * kvalues_iq4nl_f[q8[1]];
 869    reg[2] = d * kvalues_iq4nl_f[q8[2]];
 870    reg[3] = d * kvalues_iq4nl_f[q8[3]];
 871}
 872
// Dequantize 16 weights of an IQ4_XS super-block into a 4x4 register tile.
//   il - 0..15; il/2 selects the 32-weight sub-block, il%2 selects the low (0) or high (1)
//        nibbles of its 16 quant bytes
// The sub-block scale is 6 bits: a nibble from scales_l plus 2 bits from scales_h, stored with
// an offset of 32.
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    // 16 quant bytes (4 x uint32) per sub-block of 32
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
    const float d = (float)xb->d * (ls - 32);
    uint32_t aux32;
    // byte view of aux32; after masking each byte holds one 4-bit codebook index
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}
 892
 893enum ggml_sort_order {
 894    GGML_SORT_ORDER_ASC,
 895    GGML_SORT_ORDER_DESC,
 896};
 897
// Coefficients of the tanh-based GELU approximation:
//   gelu(x) ~= 0.5*x*(1 + tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))
constant float GELU_COEF_A     = 0.044715f;
// "quick" GELU: x*sigmoid(1.702*x) == x/(1 + exp(GELU_QUICK_COEF*x)), hence the negative sign
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f; // sqrt(2/pi)
constant float SQRT_2_INV      = 0.70710678118654752440084436210484f; // 1/sqrt(2), for the erf-based GELU

// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
// erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5)*exp(-x^2), t = 1/(1 + p*x), for x >= 0
constant float p_erf  = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
 911
 912template<typename T>
 913inline T erf_approx(T x) {
 914    T sign_x = sign(x);
 915    x = fabs(x);
 916    T t = 1.0f / (1.0f + p_erf * x);
 917    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
 918    return sign_x * y;
 919}
 920
 921template<typename T> T elu_approx(T x);
 922
 923template<> inline float elu_approx<float>(float x) {
 924    return (x > 0.f) ? x : (exp(x) - 1);
 925}
 926
 927template<> inline float4 elu_approx<float4>(float4 x) {
 928    float4 res;
 929
 930    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
 931    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
 932    res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
 933    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
 934
 935    return res;
 936}
 937
// Function constants that specialise kernel_unary_impl when its pipeline is created:
//   FC_unary_op  - the OP_UNARY_NUM_* operation to apply
//   FC_unary_cnt - selects the flat (contiguous) indexing path of the kernel
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];

// Element-wise unary operation; exactly one `if (FC_OP == ...)` branch matches at compile time.
//   T0 - src0 element type, T - dst element type, TC - type the math is performed in
//        (the half instantiations below compute in float / float4)
// Indexing:
//   FC_CNT true : src0 and dst are addressed without strides and the element index is tgpig.x;
//                 tpitg/ntg are ignored and there is no bounds check.
//                 NOTE(review): presumably the host launches one thread per threadgroup and one
//                 threadgroup per element on this path - confirm.
//   FC_CNT false: tgpig.z = i03, tgpig.y = i02, tgpig.x = k0*ne01 + i01; thread tpitg.x of chunk
//                 k0 handles element k0*ntg.x + tpitg.x of row (i01, i02, i03), skipping i0 >= ne0.
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
        constant ggml_metal_kargs_unary & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
#define FC_OP  FC_unary_op
#define FC_CNT FC_unary_cnt

    device const T0 * src0_ptr;
    device       T  * dst_ptr;

    int i0;

    if (FC_CNT) {
        i0 = tgpig.x;

        src0_ptr = (device const T0 *) (src0);
        dst_ptr  = (device       T  *) (dst);
    } else {
        const int i03 = tgpig.z;
        const int i02 = tgpig.y;
        // tgpig.x packs the chunk index k0 and the row index i01
        const int k0  = tgpig.x/args.ne01;
        const int i01 = tgpig.x - k0*args.ne01;

        i0 = k0*ntg.x + tpitg.x;

        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
    }

    {
        //threadgroup_barrier(mem_flags::mem_none);

        // the last chunk of a row may extend past ne0
        if (!FC_CNT) {
            if (i0 >= args.ne0) {
                return;
            }
        }

        // load once, converting to the compute type
        const TC x = (TC) src0_ptr[i0];

        if (FC_OP == OP_UNARY_NUM_SCALE) {
            dst_ptr[i0] = (T) (args.scale * x + args.bias);
        }

        if (FC_OP == OP_UNARY_NUM_FILL) {
            dst_ptr[i0] = (T) args.val;
        }

        if (FC_OP == OP_UNARY_NUM_CLAMP) {
            dst_ptr[i0] = (T) clamp(x, args.min, args.max);
        }

        if (FC_OP == OP_UNARY_NUM_SQR) {
            dst_ptr[i0] = (T) (x * x);
        }

        if (FC_OP == OP_UNARY_NUM_SQRT) {
            dst_ptr[i0] = (T) sqrt(x);
        }

        if (FC_OP == OP_UNARY_NUM_SIN) {
            dst_ptr[i0] = (T) sin(x);
        }

        if (FC_OP == OP_UNARY_NUM_COS) {
            dst_ptr[i0] = (T) cos(x);
        }

        if (FC_OP == OP_UNARY_NUM_LOG) {
            dst_ptr[i0] = (T) log(x);
        }

        // x for x > 0, slope*x otherwise (written with masks so it also works on vectors)
        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
            dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
        }

        if (FC_OP == OP_UNARY_NUM_TANH) {
            dst_ptr[i0] = (T) precise::tanh(x);
        }

        if (FC_OP == OP_UNARY_NUM_RELU) {
            dst_ptr[i0] = (T) fmax(0, x);
        }

        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
            dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
        }

        // tanh approximation of GELU
        if (FC_OP == OP_UNARY_NUM_GELU) {
            dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
        }

        // erf form of GELU: 0.5*x*(1 + erf(x/sqrt(2)))
        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
            dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
        }

        // x*sigmoid(1.702*x); GELU_QUICK_COEF is negative
        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
            dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
        }

        if (FC_OP == OP_UNARY_NUM_SILU) {
            dst_ptr[i0] = (T) (x / (1 + exp(-x)));
        }

        if (FC_OP == OP_UNARY_NUM_ELU) {
            dst_ptr[i0] = (T) elu_approx(x);
        }

        if (FC_OP == OP_UNARY_NUM_NEG) {
            dst_ptr[i0] = (T) -x;
        }

        if (FC_OP == OP_UNARY_NUM_ABS) {
            dst_ptr[i0] = (T) fabs(x);
        }

        // +1 / 0 / -1
        if (FC_OP == OP_UNARY_NUM_SGN) {
            dst_ptr[i0] = T(x > 0) - T(x < 0);
        }

        if (FC_OP == OP_UNARY_NUM_STEP) {
            dst_ptr[i0] = T(x > 0);
        }

        // x * clamp(x/6 + 0.5, 0, 1)
        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
            dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
        }

        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
            dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
        }

        if (FC_OP == OP_UNARY_NUM_EXP) {
            dst_ptr[i0] = (T) exp(x);
        }

        // select(a, b, c) yields b where c holds: for x > 20 return x directly, since
        // log(1 + exp(x)) ~= x there and exp(x) would eventually overflow
        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
            dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
        }

        if (FC_OP == OP_UNARY_NUM_EXPM1) {
            // TODO: precise implementation
            dst_ptr[i0] = (T) (exp(x) - 1);
        }
    }

#undef FC_OP
#undef FC_CNT
}

typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;

// src type, dst type, compute type
template [[host_name("kernel_unary_f32_f32")]]   kernel kernel_unary_t kernel_unary_impl<float,  float,  float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]]   kernel kernel_unary_t kernel_unary_impl<half,   half,   float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4,  half4,  float4>;
1101
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
// number of src1 operands fused into one pass; each lives at byte offset args.o1[j] of src1
constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
// row-broadcast fast path (see kernel)
constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];

// Element-wise binary op of src0 with FC_F src1 operands, applied left to right:
//   dst = ((src0 op src1_0) op src1_1) ...
// src1 is broadcast over src0 by taking every src1 index modulo its extent.
// FC_RB true : one element per threadgroup, element index tgpig.x over the whole flattened tensor;
//              src1 is a single row indexed with i0 % ne10. No bounds check and args.offs is
//              not applied on this path.
// FC_RB false: tgpig = (i01, i02, i03) selects a row; threads stride over ne0. args.offs is
//              added to the src0 and dst byte offsets. At most 8 src1 operands (src1_ptr[8]).
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F  FC_bin_f
#define FC_RB FC_bin_rb

    if (FC_RB) {
        // row broadcast
        const uint i0 = tgpig.x;
        const uint i1 = i0%args.ne10;

        device const T0 * src0_row = (device const T0 *) (src0);
        device       T  * dst_row  = (device       T  *) (dst);

        if (FC_F == 1) {
            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);

            if (FC_OP == 0) {
                dst_row[i0] = src0_row[i0] + src1_row[i1];
            }

            if (FC_OP == 1) {
                dst_row[i0] = src0_row[i0] - src1_row[i1];
            }

            if (FC_OP == 2) {
                dst_row[i0] = src0_row[i0] * src1_row[i1];
            }

            if (FC_OP == 3) {
                dst_row[i0] = src0_row[i0] / src1_row[i1];
            }
        } else {
            // NOTE(review): accumulates in T0 here but in T on the non-broadcast path; all current
            // instantiations use T0 == T
            T0 res = src0_row[i0];

            if (FC_OP == 0) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 1) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 2) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 3) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            dst_row[i0] = res;
        }
    } else {
        const int i03 = tgpig.z;
        const int i02 = tgpig.y;
        const int i01 = tgpig.x;

        if (i01 >= args.ne01) {
            return;
        }

        // broadcast src1 over src0 in dims 1..3
        const int i13 = i03%args.ne13;
        const int i12 = i02%args.ne12;
        const int i11 = i01%args.ne11;

        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);

        if (FC_F == 1) {
            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);

            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
                const int i10 = i0%args.ne10;

                if (FC_OP == 0) {
                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
                }

                if (FC_OP == 1) {
                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
                }

                if (FC_OP == 2) {
                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
                }

                if (FC_OP == 3) {
                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
                }
            }
        } else {
            // row pointers for every fused operand; the array bounds FC_F at 8
            device const T1 * src1_ptr[8];
            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
            }

            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
                const int i10 = i0%args.ne10;

                T res = src0_ptr[i0];

                if (FC_OP == 0) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res += src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 1) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res -= src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 2) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res *= src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 3) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res /= src1_ptr[j][i10];
                    }
                }

                dst_ptr[i0] = res;
            }
        }
    }

#undef FC_OP
#undef FC_F
#undef FC_RB
}

typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;

template [[host_name("kernel_bin_fuse_f32_f32_f32")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float,  float,  float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
1262
// Adds to each row of src0 the src1 row chosen by an int32 id read from src2:
//   dst[i2][i1][:] = src0[i2][i1][:] + src1[id(i1, i2)][:]
// id(i1, i2) is read at byte offset i1*sizeof(int32_t) + i2*nb21 of src2.
// dst is addressed as a contiguous float tensor (row stride ne0 floats, plane stride ne1 rows).
// Grid: tgpig.x = i1, tgpig.y = i2; the threads of a threadgroup stride over the ne0 columns.
kernel void kernel_add_id(
        constant ggml_metal_kargs_add_id & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int row   = tgpig.x;
    const int plane = tgpig.y;

    device const int32_t * id_ptr = (device const int32_t *) (src2 + row*sizeof(int32_t) + plane*args.nb21);
    const int sel = *id_ptr;

    const size_t dst_row_stride   = args.ne0 * sizeof(float);
    const size_t dst_plane_stride = args.ne1 * dst_row_stride;

    device const float * lhs = (device const float *) (src0 + row*args.nb01 + plane*args.nb02);
    device const float * rhs = (device const float *) (src1 + sel*args.nb11);
    device       float * out = (device       float *) (dst  + row*dst_row_stride + plane*dst_plane_stride);

    for (int k = tpitg.x; k < args.ne0; k += ntg.x) {
        out[k] = lhs[k] + rhs[k];
    }
}
1288
// Tiles src0 into dst (ggml repeat):
//   dst[i3][i2][i1][i0] = src0[i3 % ne03][i2 % ne02][i1 % ne01][i0 % ne00]
// Grid: tgpig = (i1, i2, i3) selects a destination row; the threads of the threadgroup stride
// over its ne0 elements. Both tensors are addressed through their byte strides (nb*).
template<typename T>
kernel void kernel_repeat(
        constant ggml_metal_kargs_repeat & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // source coordinates wrap around the (smaller) source extents
    const int i03 = i3%args.ne03;
    const int i02 = i2%args.ne02;
    const int i01 = i1%args.ne01;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i00 = i0%args.ne00;
        // read through a const pointer: src0 is an input buffer and its const must not be cast away
        *((device T *)(dst_ptr + i0*args.nb0)) = *((device const T *)(src0_ptr + i00*args.nb00));
    }
}

typedef decltype(kernel_repeat<float>) kernel_repeat_t;

template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1320
// ReGLU: dst[i] = relu(a[i]) * b[i], where a is read from the src0 row starting at element
// args.i00 and b from the src1 row starting at element args.i10.
// One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_reglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        // select rather than multiply by the mask: x*0 is NaN when x is inf/NaN, but gated-off
        // elements (x0 <= 0 or x0 NaN) must produce exactly 0
        dst_row[i0] = x0 > 0.0f ? x0*x1 : 0.0f;
    }
}
1340
// GeGLU (tanh approximation): dst[i] = gelu(a[i]) * b[i], with a read from the src0 row starting
// at element args.i00 and b from the src1 row starting at element args.i10.
// One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_geglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * a   = (device const float *) (src0 + tgpig*args.nb01) + args.i00;
    device const float * b   = (device const float *) (src1 + tgpig*args.nb11) + args.i10;
    device       float * out = (device       float *) (dst  + tgpig*args.nb1);

    for (int k = tpitg; k < args.ne0; k += ntg) {
        const float v = a[k];
        out[k] = (0.5f*v*(1.0f + precise::tanh(SQRT_2_OVER_PI*v*(1.0f + GELU_COEF_A*v*v))))*b[k];
    }
}
1362
// SwiGLU: dst[i] = silu(a[i]) * b[i], with a read from the src0 row starting at element
// args.i00 and b from the src1 row starting at element args.i10.
// One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_swiglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * a   = (device const float *) (src0 + tgpig*args.nb01) + args.i00;
    device const float * b   = (device const float *) (src1 + tgpig*args.nb11) + args.i10;
    device       float * out = (device       float *) (dst  + tgpig*args.nb1);

    for (int k = tpitg; k < args.ne0; k += ntg) {
        const float v = a[k];
        out[k] = (v / (1.0f + exp(-v)))*b[k];
    }
}
1384
// Clamped SwiGLU variant:
//   gate = min(a, limit), lin = clamp of b to [-limit, limit] (min first, then max),
//   dst  = gate * sigmoid(alpha * gate) * (1 + lin)
// a is read from the src0 row starting at element args.i00, b from the src1 row starting at
// element args.i10. One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_swiglu_oai_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * a   = (device const float *) (src0 + tgpig*args.nb01) + args.i00;
    device const float * b   = (device const float *) (src1 + tgpig*args.nb11) + args.i10;
    device       float * out = (device       float *) (dst  + tgpig*args.nb1);

    for (int k = tpitg; k < args.ne0; k += ntg) {
        const float gate = min(a[k], args.limit);
        const float lin  = max(min(b[k], args.limit), -args.limit);

        const float glu  = gate / (1.0f + exp(-gate * args.alpha));

        out[k] = glu * (1.0f + lin);
    }
}
1410
// GeGLU (erf form): dst[i] = 0.5*a[i]*(1 + erf(a[i]/sqrt(2))) * b[i], using erf_approx.
// a is read from the src0 row starting at element args.i00, b from the src1 row starting at
// element args.i10. One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_geglu_erf_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * a   = (device const float *) (src0 + tgpig*args.nb01) + args.i00;
    device const float * b   = (device const float *) (src1 + tgpig*args.nb11) + args.i10;
    device       float * out = (device       float *) (dst  + tgpig*args.nb1);

    for (int k = tpitg; k < args.ne0; k += ntg) {
        const float v = a[k];
        out[k] = (0.5f*v*(1.0f + erf_approx<float>(v*SQRT_2_INV)))*b[k];
    }
}
1432
// GeGLU (quick form): dst[i] = a[i]*sigmoid(1.702*a[i]) * b[i].
// a is read from the src0 row starting at element args.i00, b from the src1 row starting at
// element args.i10. One threadgroup per row (tgpig); its threads stride over the ne0 outputs.
kernel void kernel_geglu_quick_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * a   = (device const float *) (src0 + tgpig*args.nb01) + args.i00;
    device const float * b   = (device const float *) (src1 + tgpig*args.nb11) + args.i10;
    device       float * out = (device       float *) (dst  + tgpig*args.nb1);

    for (int k = tpitg; k < args.ne0; k += ntg) {
        const float v = a[k];
        out[k] = (v*(1.0f/(1.0f+exp(GELU_QUICK_COEF*v))))*b[k];
    }
}
1454
// Sums the first args.np floats of src0 into dst[0].
// Each thread accumulates a strided partial sum, simd_sum reduces within each simdgroup,
// and simdgroup 0 reduces the per-simdgroup partials stored in shmem_f32.
// shmem_f32 needs one float per simdgroup ((ntg.x + 31)/32 entries).
// NOTE(review): tgpig is unused and every threadgroup writes dst[0], so the host presumably
// dispatches a single threadgroup - confirm.
kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
        threadgroup  float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {

    if (args.np == 0) {
        // the sum of an empty tensor is 0: write it instead of leaving dst[0] untouched
        // (uniform across the threadgroup, so no thread skips a barrier on its own)
        if (tpitg.x == 0) {
            dst[0] = 0.0f;
        }
        return;
    }

    // TODO: become function constant
    const uint nsg = (ntg.x + 31) / 32;

    float sumf = 0;

    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
        sumf += src0[i0];
    }

    sumf = simd_sum(sumf);

    // publish one partial per simdgroup
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        // only the first nsg slots were written; the remaining lanes contribute 0
        float v = 0;

        if (tpitg.x < nsg) {
            v = shmem_f32[tpitg.x];
        }

        const float total = simd_sum(v);

        if (tpitg.x == 0) {
            dst[0] = total;
        }
    }
}
1503
// Sums (norm == false) or averages (norm == true) every row of src0 along dim 0.
// Grid: tgpig = (i1, i2, i3) selects the row; threads of the threadgroup stride over ne00.
// shmem_f32 must hold 32 floats: simdgroup 0 zeroes all 32 slots, so the slots of simdgroups
// that do not exist contribute 0 to the final simd_sum.
template <bool norm>
kernel void kernel_sum_rows(
        constant ggml_metal_kargs_sum_rows & args,
        device const float * src0,
        device       float * dst,
        threadgroup  float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    int64_t i3 = tgpig.z;
    int64_t i2 = tgpig.y;
    int64_t i1 = tgpig.x;

    // uniform per threadgroup, so the whole threadgroup leaves together before any barrier
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    float sumf = 0;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += src_row[i0];
    }

    sumf = simd_sum(sumf);

    // the zeroing above must complete before any simdgroup publishes its partial
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // every simdgroup reduces the 32 partials; only thread 0 writes the result
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        dst_row[0] = norm ? sumf / args.ne00 : sumf;
    }
}

typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;

template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
1558
// First pass of a blocked inclusive prefix sum along dim 0 (float only; T is currently unused).
// Each threadgroup scans one block of ntg.x consecutive elements of one row:
//   tgpig[0] = ib*ne01 + i01 (block ib of row i01), tgpig[1] = i02, tgpig[2] = i03.
// The block-local inclusive sums go to dst, addressed as a contiguous tensor (row stride ne00).
// When args.outb is set, the block total is also stored at tmp_row[ib] for a later pass.
// shmem needs one float per simdgroup. A simdgroup's total is published by its lane
// N_SIMDWIDTH - 1; a partial last simdgroup never publishes, which is harmless because no
// later simdgroup reads its total.
template<typename T>
kernel void kernel_cumsum_blk(
        constant ggml_metal_kargs_cumsum_blk & args,
        device const char * src0,
        device       char * tmp,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int ib = tgpig[0]/args.ne01;

    const int i00 = ib*ntg.x;
    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 +
            args.nb01*i01 +
            args.nb02*i02 +
            args.nb03*i03);

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    // out-of-range threads contribute 0 but still take part in the scan and barriers
    float v = 0.0f;

    if (i00 + tpitg.x < args.ne00) {
        v = src0_row[i00 + tpitg.x];
    }

    float s = simd_prefix_inclusive_sum(v);

    // the last lane's inclusive sum is the simdgroup total
    if (tiisg == N_SIMDWIDTH - 1) {
        shmem_f32[sgitg] = s;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // turn the simdgroup totals into exclusive offsets in place; lane k only depends on
    // slots < k, so unwritten slots beyond the last simdgroup never affect a used offset
    if (sgitg == 0) {
        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    s += shmem_f32[sgitg];

    device float * dst_row = (device float *) dst +
        args.ne00*i01 +
        args.ne00*args.ne01*i02 +
        args.ne00*args.ne01*args.ne02*i03;

    if (i00 + tpitg.x < args.ne00) {
        dst_row[i00 + tpitg.x] = s;
    }

    // the last thread holds the inclusive total of the whole block
    if (args.outb && tpitg.x == ntg.x - 1) {
        device float * tmp_row = (device float *) tmp +
            args.net0*i01 +
            args.net0*args.net1*i02 +
            args.net0*args.net1*args.net2*i03;

        tmp_row[ib] = s;
    }
}

typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;

template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
1629
// Second pass of the blocked prefix sum: adds tmp_row[ib - 1] to every element of block ib
// (block 0 is left unchanged). Grid layout matches kernel_cumsum_blk:
//   tgpig[0] = ib*ne01 + i01, tgpig[1] = i02, tgpig[2] = i03, ntg.x elements per block.
// NOTE(review): tmp presumably holds the inclusive prefix sums of the per-block totals, making
// the block-local sums row-wide - confirm against the host dispatch. T is currently unused.
template<typename T>
kernel void kernel_cumsum_add(
        constant ggml_metal_kargs_cumsum_add & args,
        device const char * tmp,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int blk  = tgpig[0]/args.ne01;
    const int base = blk*ntg.x;

    // block 0 has no predecessor; threads past the end of the row have nothing to do
    if (blk == 0 || !(base + tpitg.x < args.ne00)) {
        return;
    }

    const int row = tgpig[0]%args.ne01;
    const int j2  = tgpig[1];
    const int j3  = tgpig[2];

    device const float * offsets = (device const float *) (tmp + args.nbt1*row + args.nbt2*j2 + args.nbt3*j3);

    device float * out = (device float *) dst +
        args.ne00*row +
        args.ne00*args.ne01*j2 +
        args.ne00*args.ne01*args.ne02*j3;

    out[base + tpitg.x] += offsets[blk - 1];
}

typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;

template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
1669
1670
// Triangle-membership predicates used by kernel_tri, specialised on the GGML_TRI_TYPE_* value.
// i is the column index and r the row index of an element; true means the element is kept.
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);

// strictly below the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
    return r > i;
}

// below the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
    return r >= i;
}

// strictly above the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
    return r < i;
}

// above the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
    return r <= i;
}
1693
// Triangular mask: copies the elements of each row of src0 for which
// _ggml_vec_tri_cmp<ttype>(column, row) holds and writes zero everywhere else.
//
// Grid: one threadgroup per row, tgpig = (i1, i2, i3); threads stride over the
// row by ntg.x, so any threadgroup width covers the whole row.
//
// Masked-out elements are produced with a select rather than by multiplying
// with the 0/1 predicate: 0 * inf and 0 * NaN are NaN, so a multiplication
// would leak non-finite source values into positions that must be exactly 0.
template<typename T, int ttype>
kernel void kernel_tri(
        constant ggml_metal_kargs_tri & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    device const T * src_row = (device const T *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       T * dst_row = (device       T *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // Each thread is a single element of the row if ne00 < threads per
    // threadgroup; otherwise it loops over every ntg.x-th element.
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        // select (still branchless) so masked-out positions are exactly zero
        dst_row[i0] = _ggml_vec_tri_cmp<ttype>(i0, i1) ? src_row[i0] : static_cast<T>(0.0f);
    }
}
1721
// Function type shared by all kernel_tri instantiations.
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;

// Host-visible instantiations: <type>_<n>, where n is the ttype tag passed to
// _ggml_vec_tri_cmp (0 = UPPER_DIAG, 1 = UPPER, 2 = LOWER_DIAG, 3 = LOWER).
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
// bf16 variants exist only when the toolchain supports bfloat (MSL >= 3.1)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
1738
// Row-wise softmax with optional mask, ALiBi slope and attention sink:
//   dst[i] = exp(x[i]*scale + slope*mask[i] - max) / sum
// One threadgroup per row (tgpig = (i01, i02, i03)); threads stride over the
// row by tptg.x. src0 and dst are accessed as f32, the mask (src1) as T.
// The mask is broadcast over dims 2/3 via i02%ne12 and i03%ne13.
// A src1/src2 buffer that aliases src0 means "absent" (mask / sink).
// When present, src2[i02] is one extra logit per head: it takes part in the max
// and adds exp(src2[i02] - max) to the denominator, but produces no output.
// buf must hold at least N_SIMDWIDTH floats; the cross-simdgroup reduction
// assumes at most N_SIMDWIDTH simdgroups per threadgroup.
template<typename T>
kernel void kernel_soft_max(
        constant ggml_metal_kargs_soft_max & args,
        device const  char * src0,
        device const  char * src1,
        device const  char * src2,
        device        char * dst,
        threadgroup  float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;

    // mask indices: broadcast over dims 2 and 3
    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;

    device const float * psrc0 =                (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const     T * pmask = src1 != src0 ? (device const T *    ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float * psrc2 = src2 != src0 ? (device const float *) (src2)                                                 : nullptr;
    device       float * pdst  =                (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);

    float slope = 1.0f;

    // ALiBi: slope = m0^(h+1) for the first n_head_log2 heads,
    //        m1^(2*(h - n_head_log2) + 1) for the remaining ones
    if (args.max_bias > 0.0f) {
        const int32_t h = i02;

        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

        slope = pow(base, exp);
    }

    // parallel max (seeded with the sink logit so it takes part in the max)
    float lmax = psrc2 ? psrc2[i02] : -INFINITY;

    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
    }

    // find the max value in the block
    float max_val = simd_max(lmax);
    if (tptg.x > N_SIMDWIDTH) {
        // simdgroup 0 resets all N_SIMDWIDTH slots so unused slots read as -INF
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // lane 0 of every simdgroup publishes its partial max
        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // every simdgroup reduces all partials, so all threads end with the row max
        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum; the unnormalized exponentials are stored in dst and
    // rescaled by 1/sum in the final loop
    float lsum = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
        lsum += exp_psrc0;
        pdst[i00] = exp_psrc0;
    }

    // This barrier fixes a failing test
    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
    // NOTE(review): presumably it keeps simdgroup 0 from re-initializing buf
    // below while other simdgroups may still be reading buf[tiisg] from the
    // max reduction above (write-after-read) -- confirm via the linked thread.
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (tptg.x > N_SIMDWIDTH) {
        // same cross-simdgroup pattern as for the max, with 0 as the neutral value
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // the sink contributes to the denominator only
    if (psrc2) {
        sum += exp(psrc2[i02] - max_val);
    }

    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        pdst[i00] *= inv_sum;
    }
}
1844
// Vectorized (float4) variant of kernel_soft_max: each thread processes four
// consecutive elements per step. The loops run to ne00/4, so ne00 is presumably
// a multiple of 4 whenever this kernel is selected -- TODO confirm host side.
//
// Same contract as kernel_soft_max:
//   dst[i] = exp(x[i]*scale + slope*mask[i] - max) / sum
// with an optional mask (src1, element type T, broadcast over dims 2/3), ALiBi
// slope and per-head sink logit src2[i02]; a src1/src2 buffer that aliases
// src0 means "absent".
//
// The mask vector is widened to float4 before it is scaled by the f32 slope.
// Multiplying the float scalar by a half4 directly would convert the slope to
// half and compute the ALiBi bias in f16. The scalar kernel_soft_max promotes
// half*float to f32, so widening keeps both kernels consistent.
//
// buf must hold at least N_SIMDWIDTH floats; the cross-simdgroup reduction
// assumes at most N_SIMDWIDTH simdgroups per threadgroup.
template<typename T>
kernel void kernel_soft_max_4(
        constant ggml_metal_kargs_soft_max & args,
        device const  char * src0,
        device const  char * src1,
        device const  char * src2,
        device        char * dst,
        threadgroup  float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;

    // mask indices: broadcast over dims 2 and 3
    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;

    device const float4 * psrc4 =                (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const      T * pmask = src1 != src0 ? (device const T *     ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float *  psrc2 = src2 != src0 ? (device const float * ) (src2)                                                 : nullptr;
    device       float4 * pdst4 =                (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);

    float slope = 1.0f;

    // ALiBi (exponent named `ex` so it does not shadow the exp() builtin)
    if (args.max_bias > 0.0f) {
        const int32_t h = i02;

        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
        const int   ex   = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

        slope = pow(base, ex);
    }

    // parallel max (seeded with the sink logit so it takes part in the max)
    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;

    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        const float4 mv = pmask ? slope*float4(pmask[i00]) : float4(0.0f);
        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + mv);
    }

    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

    float max_val = simd_max(lmax);
    if (tptg.x > N_SIMDWIDTH) {
        // simdgroup 0 resets all slots so unused ones read as -INF
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum; unnormalized exponentials are stored in dst and rescaled below
    float4 lsum4 = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        const float4 mv = pmask ? slope*float4(pmask[i00]) : float4(0.0f);
        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + mv) - max_val);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }

    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

    // This barrier fixes a failing test
    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
    // (presumably orders the buf reads above before buf is re-initialized below)
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // the sink contributes to the denominator only
    if (psrc2) {
        sum += exp(psrc2[i02] - max_val);
    }

    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        pdst4[i00] *= inv_sum;
    }
}
1952
// Function types shared by the instantiations below (T is the mask element type).
typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;

// The f16/f32 suffix names the mask (src1) element type; src0 and dst are
// accessed as f32 by both kernels. The _4 variants process float4 vectors.
template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
1960
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
//
// Computes one output element per threadgroup, addressed by
// (row, token, sequence) = (tgpig.x, tgpig.y, tgpig.z); tpitg/ntg are unused.
//   out = sum over k < ne10 of window[k] * weights[k]
// `window` starts tok*nb00 bytes into row `row` of sequence `seq` of src0
// (presumably nb00 == sizeof(float), making it a sliding window over tokens);
// `weights` is row `row` of src1.
kernel void kernel_ssm_conv_f32_f32(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t row = tgpig.x;
    const int64_t tok = tgpig.y;
    const int64_t seq = tgpig.z;

    const int64_t width = args.ne10; // number of conv taps

    device const char * src0_bytes = (device const char *) src0;
    device const char * src1_bytes = (device const char *) src1;
    device       char * dst_bytes  = (device       char *) dst;

    device const float * window  = (device const float *) (src0_bytes + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device const float * weights = (device const float *) (src1_bytes + row*args.nb11);
    device       float * out     = (device       float *) (dst_bytes  + row*args.nb0  + tok*args.nb1  + seq*args.nb2);

    float acc = 0.0f;
    for (int64_t k = 0; k < width; ++k) {
        acc += window[k] * weights[k];
    }

    *out = acc;
}
1992
// float4 variant of kernel_ssm_conv_f32_f32: same addressing, but the taps are
// consumed four at a time with dot(). Only ne10/4 vectors are read, so the tap
// count is presumably a multiple of 4 when this kernel is chosen.
kernel void kernel_ssm_conv_f32_f32_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t row = tgpig.x;
    const int64_t tok = tgpig.y;
    const int64_t seq = tgpig.z;

    const int64_t width = args.ne10; // number of conv taps (scalars)

    device const char * src0_bytes = (device const char *) src0;
    device const char * src1_bytes = (device const char *) src1;
    device       char * dst_bytes  = (device       char *) dst;

    device const float4 * window  = (device const float4 *) (src0_bytes + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device const float4 * weights = (device const float4 *) (src1_bytes + row*args.nb11);
    device       float  * out     = (device       float  *) (dst_bytes  + row*args.nb0  + tok*args.nb1  + seq*args.nb2);

    float acc = 0.0f;
    for (int64_t k = 0; k < width/4; ++k) {
        acc += dot(window[k], weights[k]);
    }

    *out = acc;
}
2023
// Function constant: tokens per threadgroup (BATCH_SIZE) for the batched ssm_conv kernels.
constant short FC_ssm_conv_bs   [[function_constant(FC_SSM_CONV + 0)]];
2025
// Batched variant of kernel_ssm_conv_f32_f32: a threadgroup covers
// FC_ssm_conv_bs consecutive tokens of one row, one token per thread.
// Grid: tgpig.x = row, tgpig.y = token batch, tgpig.z = sequence;
// tpitg.x = token offset within the batch. Threads past the last token
// (tail of the final batch, token >= ne1) exit without writing.
kernel void kernel_ssm_conv_f32_f32_batched(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const short bs = FC_ssm_conv_bs;

    const int64_t row = tgpig.x;
    const int64_t seq = tgpig.z;
    const int64_t tok = (int64_t)(tgpig.y * bs) + (int64_t) tpitg.x;

    if (tok >= (int64_t) args.ne1) {
        return;
    }

    const int64_t width = args.ne10; // conv kernel size

    device const float * weights = (device const float *) ((device const char *) src1 + row*args.nb11);
    device const float * window  = (device const float *) ((device const char *) src0 + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device       float * out     = (device       float *) ((device       char *) dst  + row*args.nb0  + tok*args.nb1  + seq*args.nb2);

    float acc = 0.0f;
    for (int64_t k = 0; k < width; ++k) {
        acc += window[k] * weights[k];
    }

    *out = acc;
}
2072
// float4 flavour of the batched ssm_conv kernel: one token per thread,
// FC_ssm_conv_bs tokens per threadgroup, taps consumed four at a time via dot()
// (only ne10/4 vectors are read). Threads with token >= ne1 exit early.
kernel void kernel_ssm_conv_f32_f32_batched_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const short bs = FC_ssm_conv_bs;

    const int64_t row = tgpig.x;
    const int64_t seq = tgpig.z;
    const int64_t tok = (int64_t)(tgpig.y * bs) + (int64_t) tpitg.x;

    if (tok >= (int64_t) args.ne1) {
        return;
    }

    const int64_t width = args.ne10; // conv kernel size (scalars)

    device const float4 * weights = (device const float4 *) ((device const char *) src1 + row*args.nb11);
    device const float4 * window  = (device const float4 *) ((device const char *) src0 + row*args.nb01 + tok*args.nb00 + seq*args.nb02);
    device       float  * out     = (device       float  *) ((device       char *) dst  + row*args.nb0  + tok*args.nb1  + seq*args.nb2);

    float acc = 0.0f;
    for (int64_t k = 0; k < width/4; ++k) {
        acc += dot(window[k], weights[k]);
    }

    *out = acc;
}
2117
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
//
// Selective-scan recurrence for state element i0 and token t:
//   dtsp = dt <= 20 ? log(1 + exp(dt)) : dt          (softplus)
//   s_t  = s_{t-1} * exp(dtsp * A0) + B_t[i0] * (x_t * dtsp)
//   y_t  = sum over all i0 of s_t * C_t[i0]
// Grid: tgpig.x = i1 (element within the head), tgpig.y = ir (head),
// tgpig.z = i3 (sequence). Each thread owns state element i0 = tpitg.x.
// The reduction below appears to assume the threadgroup has exactly d_state
// threads (sgptg*NW == d_state) -- TODO confirm against the host launch.
// Tokens are processed in batches of sgptg; simdgroup k writes y for token i2 + k.
// The initial state is read from src0 at sequence ids[i3] (src6); the final
// state is written to dst at byte offset s_off. `shared` must hold at least
// sgptg*NW + 2*sgptg floats.
kernel void kernel_ssm_scan_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device      float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgptg[[simdgroups_per_threadgroup]],
        uint3    tgpg[[threadgroups_per_grid]]) {
    constexpr short NW = N_SIMDWIDTH;

    // Shared memory layout:
    // [0..sgptg*NW-1]: partial sums, row t (NW slots) holds one partial per simdgroup for token t of the batch
    // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
    // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: softplus(dt) per token (named dA; exp(dtsp*A0) is applied per thread)
    threadgroup float * shared_sums = shared;
    threadgroup float * shared_x_dt = shared + sgptg * NW;
    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;

    // zero the whole partial-sum area once: slots >= sgptg of each row are never
    // written later and must read as 0 in the final simd_sum
    shared_sums[tpitg.x] = 0.0f;

    const int32_t i0 = tpitg.x;
    const int32_t i1 = tgpig.x;
    const int32_t ir = tgpig.y; // current head
    const int32_t i3 = tgpig.z; // current seq

    const int32_t nc  = args.d_state;
    const int32_t nr  = args.d_inner;
    const int32_t nh  = args.n_head;
    const int32_t ng  = args.n_group;
    const int32_t n_t = args.n_seq_tokens;

    const int32_t s_off = args.s_off;

    device const int32_t * ids = (device const int32_t *) src6;

    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);

    const int32_t i = i0 + i1*nc;
    const int32_t g = ir / (nh / ng); // repeat_interleave

    float s0 = s0_buff[i];
    float s  = 0.0f;

    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}

    // A is indexed modulo ne30, so a single value per head broadcasts over i0
    const float A0 = A[i0%args.ne30];

    device const float * x  = (device const float *)((device const char *) src1 + i1*args.nb10  + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20  + i3*args.nb22);                // {nh, nt, ns}
    device const float * B  = (device const float *)((device const char *) src4 +  g*args.nb41  + i3*args.nb43);                // {d_state, ng, nt, ns}
    device const float * C  = (device const float *)((device const char *) src5 +  g*args.nb51  + i3*args.nb53);                // {d_state, ng, nt, ns}

    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}

    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        // previous batch must be done reading shared_x_dt / shared_dA / shared_sums
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Pre-compute x_dt and dA for this batch of tokens
        // Only first sgptg threads do the loads and expensive math
        if (i0 < sgptg && i2 + i0 < n_t) {
            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
            device const float * x_t  = x  + i0 * args.ns12;
            device const float * dt_t = dt + i0 * args.ns21;

            const float dt0  = dt_t[0];
            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
            shared_x_dt[i0] = x_t[0] * dtsp;
            shared_dA[i0]   = dtsp;  // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // sequential scan over the tokens of this batch
        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
            const float x_dt = shared_x_dt[t];
            const float dA   = exp(shared_dA[t] * A0);

            s = (s0 * dA) + (B[i0] * x_dt);

            // partial y for token t over this simdgroup's 32 state elements
            const float sumf = simd_sum(s * C[i0]);

            if (tiisg == 0) {
                shared_sums[t*NW + sgitg] = sumf;
            }

            // recurse
            s0 = s;

            B  += args.ns42;
            C  += args.ns52;
        }

        // Advance pointers for next batch
        x  += sgptg * args.ns12;
        dt += sgptg * args.ns21;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // simdgroup sgitg reduces row sgitg (token i2 + sgitg) across all simdgroups' partials
        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);

        if (tiisg == 0 && i2 + sgitg < n_t) {
            y[sgitg*nh*nr] = sumf;
        }

        y += sgptg*nh*nr;
    }

    // final state (remains 0.0f if n_t == 0)
    s_buff[i] = s;
}
2238
// RWKV-6 WKV recurrence (f32), one threadgroup per (batch, head):
//   tgpig.x = batch_id*H + head_id, tid = tpitg.x = channel within the head.
// Assumes head_size == 64 and 64 threads per threadgroup, since tid indexes the
// 64-element threadgroup arrays. TODO confirm against the host launch.
// C is the per-token row stride of k/v/r/td (presumably H*head_size).
// Each thread keeps column tid of the head's 64x64 state S in registers,
// where S[i][tid] = state_in[... + i*head_size + tid]. Per token t:
//   y[tid]    = sum_j r[j] * (tf[j]*k[j]*v[tid] + S[j][tid])
//   S[j][tid] = S[j][tid]*td[j] + k[j]*v[tid]
// Token outputs go to dst[0 .. T*C); the final state is appended after them.
kernel void kernel_rwkv_wkv6_f32(
    device const float * k,
    device const float * v,
    device const float * r,
    device const float * tf,
    device const float * td,
    device const float * state_in,
    device       float * dst,
    constant    uint & B,
    constant    uint & T,
    constant    uint & C,
    constant    uint & H,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]])  {

    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;
    const uint tid = tpitg.x;

    // depends only on tgpig, so the whole threadgroup exits together (no divergent barrier)
    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size = C * head_size;
    const uint n_seq_tokens = T / B;

    threadgroup float _k[head_size];
    threadgroup float _r[head_size];
    threadgroup float _tf[head_size];
    threadgroup float _td[head_size];

    float state[head_size];

    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + i * head_size + tid];
    }

    // tf is per head (not per token): stage it once
    threadgroup_barrier(mem_flags::mem_threadgroup);
    _tf[tid] = tf[head_id * head_size + tid];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        // stage this token's k, r, td so every thread can read the whole head
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_val = v[t];
        float y = 0.0;

        for (uint j = 0; j < head_size; j += 4) {
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            // output uses the state from before this token's update
            float4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);

            s_vec = s_vec * td_vec + kv;
            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    // final state, stored after the T*C token outputs in the same layout as state_in
    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + i * head_size + tid] = state[i];
    }
}
2323
// RWKV-7 WKV recurrence (f32), one threadgroup per (batch, head):
//   tgpig.x = batch_id*H + head_id, tid = tpitg.x = channel within the head.
// Assumes head_size == 64 and 64 threads per threadgroup. TODO confirm host launch.
// Each thread keeps row tid of the head's 64x64 state S in registers,
// where S[tid][i] = state_in[... + tid*head_size + i]. Per token t:
//   sa        = sum_j a[j] * S[tid][j]
//   S[tid][j] = S[tid][j]*w[j] + k[j]*v[tid] + sa*b[j]
//   y[tid]    = sum_j S[tid][j] * r[j]          (uses the updated state)
// Token outputs go to dst[0 .. T*C); the final state is appended after them.
kernel void kernel_rwkv_wkv7_f32(
    device const float * r,
    device const float * w,
    device const float * k,
    device const float * v,
    device const float * a,
    device const float * b,
    device const float * state_in,
    device       float * dst,
    constant    uint & B,
    constant    uint & T,
    constant    uint & C,
    constant    uint & H,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]])  {

    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;
    const uint tid = tpitg.x;

    // depends only on tgpig, so the whole threadgroup exits together (no divergent barrier)
    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size = C * head_size;
    const uint n_seq_tokens = T / B;

    threadgroup float _r[head_size];
    threadgroup float _w[head_size];
    threadgroup float _k[head_size];
    threadgroup float _a[head_size];
    threadgroup float _b[head_size];

    float state[head_size];

    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + tid * head_size + i];
    }

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        // stage this token's r, w, k, a, b so every thread can read the whole head
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _r[tid] = r[t];
        _w[tid] = w[t];
        _k[tid] = k[t];
        _a[tid] = a[t];
        _b[tid] = b[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_val = v[t];
        float y = 0.0, sa = 0.0;

        float4 sa_vec(0.0);

        // sa = a . S[tid], computed from the state before this token's update
        for (uint j = 0; j < head_size; j += 4) {
            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            sa_vec += a_vec * s_vec;
        }
        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];

        // state update, then output from the updated state
        for (uint j = 0; j < head_size; j += 4) {
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            s_vec = s_vec * w_vec + kv + sa * b_vec;
            y += dot(s_vec, r_vec);

            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    // final state, stored after the T*C token outputs in the same layout as state_in
    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + tid * head_size + i] = state[i];
    }
}
2416
// Function constants specializing kernel_solve_tri_f32:
//   FC_solve_tri_nsg: simdgroups per threadgroup (each simdgroup solves one RHS column)
//   FC_solve_tri_n  : order N of the triangular matrix (row length of src0)
//   FC_solve_tri_k  : row stride, in elements, of the RHS and solution (presumably == ne10)
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
2420
// Solves A * X = B for X by forward substitution:
//   X[r] = (B[r] - sum_{j<r} A[r][j] * X[j]) / A[r][r]
// A (src0) is N x N; only its entries j <= r of each row r are used.
// B (src1) holds the right-hand sides; the solution is written to dst.
//
// Layout (as addressed below): A rows are contiguous with N floats each;
// B and X rows are K floats apart, one column per RHS.
// Grid: i01 = tgpig.x*NSG + sgitg is the RHS column (one simdgroup per column);
// tgpig.y / tgpig.z select the matrix in dims 2 / 3.
// Rows of A are staged through threadgroup memory NSG at a time: simdgroup
// sgitg loads row rr + sgitg. shmem must hold at least
// NSG*PAD2(N, N_SIMDWIDTH) floats.
//
// The partial dot product reads only j < r. Entries of dst not yet solved
// (whatever the buffer held before, possibly inf/NaN), the upper triangle of A
// and anything past row N are never read. A 0/1 multiplicative mask is not
// enough here, because 0 * inf and 0 * NaN are NaN. Staging likewise never
// reads past the end of a row or past row N - 1.
//
// NOTE(review): X[j] is stored to device memory by lane 0 and later read back
// by other lanes of the same simdgroup without an explicit simdgroup_barrier;
// this relies on device-memory ordering within a simdgroup -- confirm.
kernel void kernel_solve_tri_f32(
        constant ggml_metal_kargs_solve_tri & args,
        device   const char * src0,
        device   const char * src1,
        device         char * dst,
        threadgroup    char * shmem [[threadgroup(0)]],
        ushort3 tgpig[[threadgroup_position_in_grid]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    constexpr short NW = N_SIMDWIDTH;

    const short NSG = FC_solve_tri_nsg;
    const short N   = FC_solve_tri_n;
    const short K   = FC_solve_tri_k;
    const short NP  = PAD2(N, NW);

    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x*NSG + sgitg;

    threadgroup float * sh0 = (threadgroup float *) shmem;

    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;

    for (short rr = 0; rr < N; rr += NSG) {
        // all simdgroups must be done with the previous chunk of rows
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // stage row rr + sgitg of A (skip rows past the end of the matrix)
        if (rr + sgitg < N) {
            threadgroup float * sh0_cur = sh0 + sgitg*NP;

            for (short t = 0; t*NW < N; ++t) {
                const short idx = t*NW + tiisg;
                if (idx < N) {
                    sh0_cur[idx] = src0_ptr[idx];
                }
            }
        }

        src0_ptr += NSG*N;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // simdgroups without a valid RHS column still took part in the staging
        // and barriers above; they skip only the solve
        if (i01 >= args.ne10) {
            continue;
        }

        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
            const short r = rr + ir;

            threadgroup float * sh0_cur = sh0 + ir*NP;

            float sum = 0.0f;

            for (short t = 0; t*NW < r; ++t) {
                const short idx = t*NW + tiisg;
                if (idx < r) {
                    sum += sh0_cur[idx] * dst_ptr[idx*K];
                }
            }

            sum = simd_sum(sum);

            if (tiisg == 0) {
                const float diag = sh0_cur[r];

                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
            }
        }
    }
}
2493
// Row-wise argmax over f32 rows, producing an int32 per row in dst[tgpig].
// One threadgroup per row; threads stride over the ne00 elements by ntg.
// Each thread keeps its first strictly-greater hit. Candidates are reduced
// within the simdgroup and, when the threadgroup has more than one simdgroup,
// across simdgroups through shmem. Across threads, ties on the maximum value
// resolve to the larger index (simd_max over indices). -1 is produced when no
// element compares greater than -INFINITY (e.g. an all -inf/NaN row).
// shmem layout: N_SIMDWIDTH floats (values) followed by N_SIMDWIDTH int32 (indices).
kernel void kernel_argmax_f32(
        constant ggml_metal_kargs_argmax & args,
        device   const char * src0,
        device         char * dst,
        threadgroup    char * shmem [[threadgroup(0)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint    ntg[[threads_per_threadgroup]]) {
    device const float   * row = (device const float *) (src0 + tgpig*args.nb01);
    device       int32_t * out = (device int32_t *) dst;

    // per-thread scan
    float   best_val = -INFINITY;
    int32_t best_idx = -1;

    for (int i = tpitg; i < args.ne00; i += ntg) {
        const float v = row[i];
        if (v > best_val) {
            best_val = v;
            best_idx = i;
        }
    }

    // simdgroup reduction
    const float   sg_val = simd_max(best_val);
    const int32_t sg_idx = simd_max(select(-1, best_idx, best_val == sg_val));

    // single simdgroup: done (ntg is uniform, so the whole threadgroup takes this path)
    if (ntg <= N_SIMDWIDTH) {
        out[tgpig] = sg_idx;
        return;
    }

    // cross-simdgroup reduction
    threadgroup float   * sh_val = (threadgroup float   *) shmem;
    threadgroup int32_t * sh_idx = (threadgroup int32_t *) shmem + N_SIMDWIDTH;

    if (sgitg == 0) {
        sh_val[tiisg] = -INFINITY;
        sh_idx[tiisg] = -1;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        sh_val[sgitg] = sg_val;
        sh_idx[sgitg] = sg_idx;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float   cand_val = sh_val[tiisg];
    const int32_t cand_idx = sh_idx[tiisg];
    const float   tg_val   = simd_max(cand_val);

    out[tgpig] = simd_max(select(-1, cand_idx, cand_val == tg_val));
}
2553
// Layer normalization of one row per threadgroup (row = tgpig.x/y/z), with an optional
// fused epilogue selected by F:
//   F == 1 : y = norm(x)
//   F == 2 : y = norm(x)*f0
//   F == 3 : y = norm(x)*f0 + f1
// where norm(x) = (x - mean(x)) / sqrt(var(x) + eps).
// T is float or float4; the loops run over ne00_t elements of T, while the statistics are
// divided by ne00. The f0/f1 rows are broadcast via the modulo on nef1/nef2/nef3.
// shmem_f32 is a 32-float scratch for the cross-simdgroup reductions.
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // slots of simdgroups that do not exist must contribute 0 to the reductions
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;

    device const T * src = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);

    device const T * src_mul = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
    device const T * src_add = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);

    device T * out = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);

    // pass 1: row sum -> mean
    T acc(0.0f);
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        acc += src[i];
    }

    float partial = simd_sum(dot(acc, T(1.0f)));

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = partial;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean = simd_sum(shmem_f32[tiisg])/args.ne00;

    // pass 2: write the centered row to dst and accumulate its sum of squares
    partial = 0.0f;
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        out[i] = src[i] - mean;
        partial += dot(out[i], out[i]);
    }
    partial = simd_sum(partial);

    // every simdgroup must be done reading the first reduction before the scratch is reused
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = partial;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float variance = simd_sum(shmem_f32[tiisg])/args.ne00;
    const float scale    = 1.0f/sqrt(variance + args.eps);

    // pass 3: scale the centered row and apply the fused epilogue
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        if (F == 1) {
            out[i] = (out[i]*scale);
        } else if (F == 2) {
            out[i] = (out[i]*scale)*src_mul[i];
        } else if (F == 3) {
            out[i] = (out[i]*scale)*src_mul[i] + src_add[i];
        }
    }
}

typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;

template [[host_name("kernel_norm_f32")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;

template [[host_name("kernel_norm_f32_4")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
2651
// RMS normalization of one row per threadgroup, with an optional fused epilogue:
//   F == 1 : y = x*s
//   F == 2 : y = (x*s)*f0
//   F == 3 : y = (x*s)*f0 + f1
// where s = 1/sqrt(sum(x^2)/ne00 + eps).
// T is float or float4; loops run over ne00_t elements of T. The f0/f1 rows are broadcast
// via the modulo on nef1/nef2/nef3. shmem_f32 is a 32-float cross-simdgroup scratch.
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // slots of simdgroups that do not exist must contribute 0 to the reduction
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;

    device const T * src = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);

    device const T * src_mul = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
    device const T * src_add = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);

    // per-thread sum of squares over a strided slice of the row
    float partial = 0.0f;
    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        partial += dot(src[i], src[i]);
    }
    partial = simd_sum(partial);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = partial;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean  = simd_sum(shmem_f32[tiisg])/args.ne00;
    const float scale = 1.0f/sqrt(mean + args.eps);

    device T * out = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);

    for (int i = tpitg.x; i < args.ne00_t; i += ntg.x) {
        if (F == 1) {
            out[i] = (src[i]*scale);
        } else if (F == 2) {
            out[i] = (src[i]*scale)*src_mul[i];
        } else if (F == 3) {
            out[i] = (src[i]*scale)*src_mul[i] + src_add[i];
        }
    }
}

typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;

template [[host_name("kernel_rms_norm_f32")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;

template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
2726
// L2 normalization of one row per threadgroup: y = x / max(||x||_2, eps).
// T0 is the source element type and T the destination element type (float or float4 in the
// instantiations below); the loops run over args.ne00 elements of T0.
// shmem_f32 is a 32-float scratch for the cross-simdgroup reduction.
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
        constant ggml_metal_kargs_l2_norm & args,
        device const char * src0,
        device       char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    // slots of simdgroups that do not exist must contribute 0 to the reduction
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);

    float sumf = 0.0f;

    // parallel sum of squares
    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    // eps bounds the norm itself (not its square), matching y = x / max(||x||, eps);
    // clamping sumf with eps would use an effective floor of sqrt(eps) on the norm
    const float scale = 1.0f/max(sqrt(sumf), args.eps);

    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
        y[i00] = x[i00] * scale;
    }
}

typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;

template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl<float,  float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
2779
// Group normalization: threadgroup tgpig normalizes one group of ceil(ne02/ngrp) channels,
// i.e. up to gs = ne00*ne01*ceil(ne02/ngrp) consecutive elements of src0, to zero mean and
// unit variance (dst = (x - mean)/sqrt(var + eps)).
// NOTE(review): src0/dst are indexed as flat arrays, so this assumes contiguous tensors.
// When more than one simdgroup is launched, buf must hold 32 floats.
kernel void kernel_group_norm_f32(
        constant ggml_metal_kargs_group_norm & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // widen before multiplying so large tensors cannot overflow 32-bit products
    const int64_t ne = (int64_t) args.ne00*args.ne01*args.ne02;
    const int64_t gs = (int64_t) args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);

    // element range [g0, g1) of this group; the last group is clipped to the tensor size
    const int64_t g0 = (int64_t) tgpig * gs;
    int64_t       g1 = g0 + gs;

    if (g1 > ne) {
        g1 = ne;
    }

    // statistics are normalized by the number of elements actually in the group, which is
    // smaller than gs for a truncated last group (guarded against empty groups)
    const float cnt = (float) (g1 > g0 ? g1 - g0 : 1);

    const int64_t start = g0 + tpitg;

    float tmp = 0.0f; // partial sum for thread in simdgroup

    for (int64_t j = start; j < g1; j += ntg) {
        tmp += src0[j];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }

    const float mean = tmp / cnt;
    tmp = 0.0f;

    for (int64_t j = start; j < g1; j += ntg) {
        float xi = src0[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        // every simdgroup must have finished reading buf from the first reduction
        // before simdgroup 0 clears it again
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }

    const float variance = tmp / cnt;
    const float scale = 1.0f/sqrt(variance + args.eps);
    for (int64_t j = start; j < g1; j += ntg) {
        dst[j] *= scale;
    }
}
2860
// Computes the inner product between half of a q4_0 block and 16 floats (yl).
//   qb_curr : the block; d is its scale, the packed 4-bit quants follow d
//   sumy    : SUM(yl[i]) of the unscaled y values, used to apply the -8 quant offset
//   yl      : 16 y values pre-multiplied by the scale factors that stand in for the missing
//             bit shifts (1, 1/16, 1/256, 1/4096), so each quant can be masked in place
//             inside its uint16 word instead of being shifted down
//   il      : where the q4 quants begin (0 or QK4_0/4), in bytes
// Returns d * (sum_k yl[k]*q[k] - 8*sumy).
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;

    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };

    // +1 uint16 skips d; il is a byte offset, hence il/2 uint16 words
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);

    // each uint16 word holds 4 quants: low/high nibble of two consecutive bytes
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
    }

    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
2881
// Computes the inner product between half of a q4_1 block and 16 floats (yl).
//   qb_curr : the block; d is its scale and m its minimum, the packed 4-bit quants follow
//   sumy    : SUM(yl[i]) of the unscaled y values, used to apply the minimum m
//   yl      : 16 y values pre-multiplied by the scale factors that stand in for the missing
//             bit shifts (1, 1/16, 1/256, 1/4096), so each quant is masked in place
//   il      : where the q4 quants begin (0 or QK4_0/4), in bytes
// Returns d * sum_k yl[k]*q[k] + m*sumy.
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };

    // +2 uint16 skips d and m; il is a byte offset, hence il/2 uint16 words
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);

    for (int i = 0; i < 8; i+=2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
    }

    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
2903
// Computes the inner product between half of a q5_0 block and 16 floats (yl).
//   qb_curr : the block; d is its scale, qh holds the 5th bit of each quant
//   sumy    : SUM(yl[i]) of the unscaled y values, used to apply the -16 quant offset
//   yl      : 16 y values pre-multiplied by the scale factors that stand in for the missing
//             bit shifts (1, 1/16, 1/256, 1/4096), so each quant is masked in place
//   il      : where the q5 quants begin (0 or QK5_0/4), in bytes
// Returns d * (sum_k yl[k]*q[k] - 16*sumy).
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;

    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };

    // +3 uint16 words presumably skip d and qh[4] (see block_q5_0 in ggml-common.h)
    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);

    // bit k of qh is OR-ed in as the 5th bit of quant k, pre-shifted to line up with the
    // unshifted nibble position inside the uint16 word; quants of the second half of the
    // block (high nibbles) use bits k + QK5_0/2
    for (int i = 0; i < 8; i+=2) {
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }

    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
}
2925
// Computes the inner product between half of a q5_1 block and 16 floats (yl).
//   qb_curr : the block; d is its scale and m its minimum, qh holds the 5th bit of each quant
//   sumy    : SUM(yl[i]) of the unscaled y values, used to apply the minimum m
//   yl      : 16 y values pre-multiplied by the scale factors that stand in for the missing
//             bit shifts (1, 1/16, 1/256, 1/4096), so each quant is masked in place
//   il      : where the q5 quants begin (0 or QK5_1/4), in bytes
// Returns d * sum_k yl[k]*q[k] + m*sumy.
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };

    // +4 uint16 words presumably skip d, m and qh[4] (see block_q5_1 in ggml-common.h)
    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);

    // bit k of qh is OR-ed in as the 5th bit of quant k; the second half of the block uses
    // bits k + QK5_1/2 (this block's own constant, rather than the q5_0 one)
    for (int i = 0; i < 8; i+=2) {
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100));
        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000));
    }

    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
2948
// Reduces NR0 per-thread partial sums across the whole threadgroup and writes the row
// results dst_f32[r0 + row] for rows below ne01.
// Stage 1 reduces within each simdgroup (the reduced value is stored back into sumf).
// Stage 2 publishes one partial per simdgroup into shmem, laid out as NR0 rows of
// N_SIMDWIDTH floats. Stage 3 folds them; only simdgroup 0 / lane 0 writes to dst.
// shmem must hold at least NR0*N_SIMDWIDTH floats.
template<short NR0>
static inline void helper_mv_reduce_and_write(
        device float * dst_f32,
        float sumf[NR0],
        const int r0,
        const int ne01,
        ushort tiisg,
        ushort sgitg,
        threadgroup char * shmem) {
    constexpr short NW = N_SIMDWIDTH;

    threadgroup float * partials = (threadgroup float *) shmem;

    // stage 1: clear the scratch (absent simdgroups must read as 0) and reduce in-simdgroup
    for (short row = 0; row < NR0; ++row) {
        if (sgitg == 0) {
            partials[row*NW + tiisg] = 0.0f;
        }

        sumf[row] = simd_sum(sumf[row]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // stage 2: lane 0 of each simdgroup publishes its partials
    if (tiisg == 0) {
        for (short row = 0; row < NR0; ++row) {
            partials[row*NW + sgitg] = sumf[row];
        }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // stage 3: fold the partials of every valid row
    for (short row = 0; row < NR0; ++row) {
        if (r0 + row >= ne01) {
            break;
        }

        const float total = simd_sum(partials[row*NW + tiisg]);

        if (sgitg == 0 && tiisg == 0) {
            dst_f32[r0 + row] = total;
        }
    }
}
2990
// Function constants shared by the mat-vec kernels (their values are supplied when the
// pipeline is specialized):
//   FC_mul_mv_nsg   - read as NSG; the kernels index rows/blocks with tgpig.x*NSG + sgitg
//                     or sgitg*NQ, so it is expected to equal the simdgroups per threadgroup
//   FC_mul_mv_nxpsg - read as nxpsg by the *_ext_* kernels: lanes cooperating on one src0
//                     row (tx = tiisg%nxpsg), giving 32/nxpsg rows per simdgroup
constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
2993
// Mat-vec product for the 32-element block formats handled by block_q_n_dot_y
// (q4_0, q4_1, q5_0, q5_1): dst row r1 of batch im = src0 rows x src1 row r1.
// Each simdgroup owns NR0 consecutive src0 rows starting at r0 = (tgpig.x*NSG + sgitg)*NR0
// and walks the full row itself. Lanes come in pairs: ix = tiisg/2 selects the block (stride
// NQ blocks) and il = 0 or 8 selects which half of the block's packed quants the lane reads.
// shmem is part of the shared signature but is not used by this variant.
template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NQ = 16;

    const int nb = args.ne00/QK4_0;

    const int r0 = (tgpig.x*NSG + sgitg)*NR0;
    const int r1 =  tgpig.y;
    const int im =  tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // pointers to src0 rows; rows past ne01 (tail of the last threadgroup) are clamped to the
    // last valid row so they never read outside src0 - their sums are discarded below anyway
    device const block_q_type * ax[NR0];
    FOR_UNROLL (int row = 0; row < NR0; ++row) {
        const int rs = (r0 + row < args.ne01) ? r0 + row : args.ne01 - 1;

        const uint64_t offset0 = rs*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
    }

    float sumf[NR0] = {0.f};

    const short ix = (tiisg/(NW/NQ));
    const short il = (tiisg%(NW/NQ))*8;

    const int ib0 = ix;

    float yl[16]; // src1 vector cache

    device const float * yb = y + ib0*QK4_0 + il;

    // each thread in a SIMD group deals with half a block; the simdgroup advances NQ blocks
    for (int ib = ib0; ib < nb; ib += NQ) {
        float sumy[2] = { 0.f, 0.f };

        // pre-scale y by 1, 1/256, 1/16, 1/4096 so block_q_n_dot_y can mask quants in place
        FOR_UNROLL (short i = 0; i < 8; i += 2) {
            sumy[0]  += yb[i +  0] + yb[i +  1];
            yl[i + 0] = yb[i +  0];
            yl[i + 1] = yb[i +  1]/256.f;

            sumy[1]  += yb[i + 16] + yb[i + 17];
            yl[i + 8] = yb[i + 16]/16.f;
            yl[i + 9] = yb[i + 17]/4096.f;
        }

        FOR_UNROLL (short row = 0; row < NR0; row++) {
            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
        }

        yb += QK4_0*NQ;
    }

    // 64-bit offset: im*ne0*ne1 can exceed the 32-bit range for large outputs
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // each simdgroup owns its rows, so a simdgroup-level reduction is sufficient
    for (int row = 0; row < NR0; ++row) {
        const float tot = simd_sum(sumf[row]);

        if (tiisg == 0 && r0 + row < args.ne01) {
            dst_f32[r0 + row] = tot;
        }
    }
}
3081
// q4_0 x f32 mat-vec entry point: forwards to mul_vec_q_n_f32_impl with N_R0_Q4_0 rows per
// simdgroup.
kernel void kernel_mul_mv_q4_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

// q4_1 x f32 mat-vec entry point: forwards to mul_vec_q_n_f32_impl with N_R0_Q4_1 rows per
// simdgroup.
kernel void kernel_mul_mv_q4_1_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

// q5_0 x f32 mat-vec entry point: forwards to mul_vec_q_n_f32_impl with N_R0_Q5_0 rows per
// simdgroup.
kernel void kernel_mul_mv_q5_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

// q5_1 x f32 mat-vec entry point: forwards to mul_vec_q_n_f32_impl with N_R0_Q5_1 rows per
// simdgroup.
kernel void kernel_mul_mv_q5_1_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
3129
// q8_0 x f32 mat-vec: all NSG simdgroups of threadgroup tgpig.x share the same NR0 src0 rows
// (r0 = tgpig.x*NR0) and split the row's blocks between them (simdgroup sgitg starts at
// block sgitg*NQ, stride NSG*NQ). Within a simdgroup, ix = tiisg/4 selects the block and
// il = tiisg%4 selects which NQ-quant slice of it the lane handles. The per-lane sums are
// combined across the threadgroup by helper_mv_reduce_and_write, which uses shmem.
template<short NR0, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NQ = 8;

    const int nb = args.ne00/QK8_0;

    const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // pointers to src0 rows; rows past ne01 (tail of the last threadgroup) are clamped to the
    // last valid row so they never read outside src0 - their sums are never written
    device const block_q8_0 * ax[NR0];
    FOR_UNROLL (short row = 0; row < NR0; ++row) {
        const int rs = (r0 + row < args.ne01) ? r0 + row : args.ne01 - 1;

        const uint64_t offset0 = rs*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
    }

    float sumf[NR0] = { 0.f };

    const short ix = tiisg/(NW/NQ);
    const short il = tiisg%(NW/NQ);

    const int ib0 = sgitg*NQ + ix;

    float yl[NQ];

    device const float * yb = y + ib0*QK8_0 + il*NQ;

    // each thread in a SIMD group deals with NQ quants at a time
    for (int ib = ib0; ib < nb; ib += NSG*NQ) {
        for (short i = 0; i < NQ; ++i) {
            yl[i] = yb[i];
        }

        for (short row = 0; row < NR0; row++) {
            device const int8_t * qs = ax[row][ib].qs + il*NQ;

            float sumq = 0.f;
            FOR_UNROLL (short i = 0; i < NQ; ++i) {
                sumq += qs[i] * yl[i];
            }

            sumf[row] += sumq*ax[row][ib].d;
        }

        yb += NSG*NQ*QK8_0;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
3203
// q8_0 x f32 mat-vec entry point: forwards to kernel_mul_mv_q8_0_f32_impl with N_R0_Q8_0
// rows per threadgroup.
[[host_name("kernel_mul_mv_q8_0_f32")]]
kernel void kernel_mul_mv_q8_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
3216
// mat-vec kernel processing in chunks of float4
// chpb - chunks per quantization block
// r1ptg  - number of src1 rows (i11 .. i11 + r1ptg - 1) handled by each threadgroup
// deq_t4 - dequantizes chunk `cch` of the block `xq` into a float4
// Layout: nxpsg lanes cooperate on one src0 row (tx = lane within the row), so a simdgroup
// covers nypsg = 32/nxpsg rows and a threadgroup covers nypsg*NSG rows.
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
void kernel_mul_mv_ext_q4_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG   = FC_mul_mv_nsg;
    const short nxpsg = FC_mul_mv_nxpsg;

    const short chpt = 4; // chunks per thread

  //const short nxpsg = (32);
    const short nypsg = (32/nxpsg);

    const short tx = tiisg%nxpsg;
    const short ty = tiisg/nxpsg;

    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
    const int i11 = tgpig.y*r1ptg;
    const int i1m = tgpig.z;

    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;

    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    // out-of-range rows fall back to the start of src0/src1 so the loads below stay in bounds;
    // their results are never written
    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;

    device const float4 * y4[r1ptg];

    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
    }

    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };

    short cch = tx%chpb; // current chunk index

    // each iteration handles chpt chunks per lane, nxpsg chunks apart
    // NOTE(review): only the first chunk (4*ich) is bounds-checked against ne00; presumably the
    // host only dispatches this kernel when ne00 is a suitable multiple - confirm
    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
        float4 lx[chpt];

#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
            deq_t4(xq, cch, lx[ch]);

            // advance nxpsg chunks, carrying whole blocks into the block pointer
            cch += nxpsg;
            if (cch >= chpb) {
                xq  += cch/chpb;
                cch %= chpb;
            }
        }

#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
#pragma unroll(r1ptg)
            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
            }
        }

#pragma unroll(r1ptg)
        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
            y4[ir1] += chpt*nxpsg;
        }
    }

    // reduce only the threads in each row
    // (shuffle-down tree over the nxpsg lanes of a row; the total ends up in lane tx == 0)
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
        }
        if (nxpsg >= 8) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
        }
        if (nxpsg >= 4) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
        }
        if (nxpsg >= 2) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
        }

        //sumf[ir1] = simd_sum(sumf[ir1]);
    }

    // lane 0 of each row writes the results for the valid (i01, i11 + ir1) pairs
    if (tx == 0) {
        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;

            if (i01 < args.ne01) {
                dst_f32[i01] = sumf[ir1];
            }
        }
    }
}
3320
// mat-vec kernel processing in chunks of float4x4
//
// A "chunk" is one float4x4, i.e. 16 consecutive elements of a row (the main loop steps 16*ich over ne00).
//
// Thread layout (from the index math below):
//   - each simdgroup is split into nypsg = 32/nxpsg groups of nxpsg lanes; each group handles one src0 row i01
//   - tx = lane within its group (strides over the row's chunks), ty = which row of the simdgroup it works on
//   - a threadgroup covers nypsg*NSG src0 rows and the r1ptg src1 rows i11 .. i11 + r1ptg - 1
//   - tgpig.z selects the (i12, i13) batch; src0 is broadcast over it through args.r2 / args.r3
//
// Template parameters:
//   r1ptg    - number of src1 rows processed per threadgroup
//   q_t      - storage type of one src0 quantization block
//   chpb     - number of 16-element chunks per quantization block
//   deq_t4x4 - dequantizes chunk `il` of block `xq` into a float4x4
//
// NSG / nxpsg are read from FC_mul_mv_nsg / FC_mul_mv_nxpsg (declared outside this chunk).
// NOTE(review): the reduction below only handles nxpsg in {1, 2, 4, 8, 16, 32} - presumably the host
// only picks powers of two <= 32; confirm.
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
void kernel_mul_mv_ext_q4x4_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG   = FC_mul_mv_nsg;
    const short nxpsg = FC_mul_mv_nxpsg;

    const short chpt = 1; // chunks dequantized per thread per loop iteration

  //const short nxpsg = (32);
    const short nypsg = (32/nxpsg); // src0 rows handled by one simdgroup

    const short tx = tiisg%nxpsg; // lane within the row group
    const short ty = tiisg/nxpsg; // row group within the simdgroup

    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; // src0 row
    const int i11 = tgpig.y*r1ptg;                          // first src1 row
    const int i1m = tgpig.z;                                // flattened batch index

    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;

    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    // rows past ne01 / ne11 fall back to the buffer base pointers (presumably so their loads stay in bounds);
    // their sums are never written (see the i01 / i11 + ir1 checks at the end)
    // xq starts at the block that holds this lane's first chunk (chunk tx lives in block tx/chpb)
    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;

    device const float4x4 * y4x4[r1ptg];

    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
    }

    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };

    // index of this lane's current chunk within block *xq
    short cch = tx%chpb;

    // each lane visits chunks tx, tx + nxpsg, tx + 2*nxpsg, ... of the row
    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
        float4x4 lx[chpt];

#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
            deq_t4x4(xq, cch, lx[ch]);

            // step nxpsg chunks forward; when that crosses block boundaries, advance xq by whole blocks
            cch += nxpsg;
            if (cch >= chpb) {
                xq  += cch/chpb;
                cch %= chpb;
            }
        }

#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
#pragma unroll(r1ptg)
            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                // 16-element dot product: row k of the dequantized chunk with row k of the src1 chunk
                sumf[ir1] +=
                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);

            }
        }

#pragma unroll(r1ptg)
        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
            y4x4[ir1] += chpt*nxpsg;
        }
    }

    // tree reduction over the nxpsg lanes of each row group; the lane with tx == 0 ends up holding the sum
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
        }
        if (nxpsg >= 8) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
        }
        if (nxpsg >= 4) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
        }
        if (nxpsg >= 2) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
        }

        //sumf[ir1] = simd_sum(sumf[ir1]);
    }

    // only in-range (src0 row, src1 row) pairs are written
    if (tx == 0) {
        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;

            if (i01 < args.ne01) {
                dst_f32[i01] = sumf[ir1];
            }
        }
    }
}
3427
// Kernel entry point for the float4-chunk "ext" mat-vec path.
// epb - elements per quantization block. The impl works on float4 chunks, so it is instantiated with
//       chpb = epb/4 chunks per block (e.g. epb = 4 for plain f32/f16 rows gives one chunk per "block").
template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
kernel void kernel_mul_mv_ext_q4_f32_disp(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    constexpr short chunks_per_block = epb/4;

    kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, chunks_per_block, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}
3441
// Kernel entry point for the float4x4-chunk "ext" mat-vec path.
// epb - elements per quantization block. Each float4x4 chunk holds 16 elements, so the impl is
//       instantiated with chpb = epb/16 chunks per block (16 for the 256-element blocks used below).
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
kernel void kernel_mul_mv_ext_q4x4_f32_disp(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    constexpr short chunks_per_block = epb/16;

    kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, chunks_per_block, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}
3453
// Kernel function types for the instantiations below. The disp kernels' signatures do not depend on
// their template arguments, so any single instantiation can supply the type.
typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;

// Host-visible "ext" mat-vec kernels, named kernel_mul_mv_ext_<src0 type>_f32_r1_<r1ptg>.
// Template arguments: <r1ptg (src1 rows per threadgroup), src0 storage type, epb (elements per block), dequantizer>.

// float4-chunk path; plain f32/f16 rows use epb = 4 (one float4 per "block")
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4,       4,  dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4,       4,  dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4,       4,  dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4,       4,  dequantize_f32_t4>;

template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;

// 32-element quantization blocks (8 float4 chunks per block)
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0,   32, dequantize_q4_0_t4>;

template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1,   32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1,   32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1,   32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1,   32, dequantize_q4_1_t4>;

template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0,   32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0,   32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0,   32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0,   32, dequantize_q5_0_t4>;

template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1,   32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1,   32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1,   32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1,   32, dequantize_q5_1_t4>;

template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0,   32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;

template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;

template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;

// float4x4-chunk path for the 256-element q4_K / q5_K / q6_K blocks (epb/16 = 16 chunks per block)
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;

template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;

template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
3516
// Generic mat-vec kernel for plain (non-quantized) element types.
//
// For row in [0, NR0) it computes
//     dst[im][r1][r0 + row] = sum_i src0[r0 + row][i] * src1[r1][i]
// with r0 = tgpig.x*NR0, r1 = tgpig.y and im = tgpig.z (batch index; src0 is broadcast via args.r2 / args.r3).
//
// Work split: a row is viewed as nb = ne00/NB blocks of NB elements. Inside a simdgroup, NW/NF lanes share
// one block and each lane takes NF consecutive elements, so one simdgroup covers NF blocks per step and the
// NSG simdgroups cover NSG*NF blocks per step. The last ne00 % NB elements go through a strided tail loop.
// Per-lane partial sums are combined and written by helper_mv_reduce_and_write.
//
// All products are formed in float. Multiplying two half (or two bfloat) values directly would round every
// product to the input precision (and can overflow half) before it reaches the float accumulator. The _4 and
// _short variants already convert to float first.
//
// NOTE(review): rows r0 + row >= ne01 are still read below; their sums are handed to
// helper_mv_reduce_and_write together with ne01, which presumably drops them - confirm that the host
// dispatch keeps these reads in bounds.
template<typename T0, typename T1, short NR0, typename args_t>
void kernel_mul_mv_t_t_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NB = 32; // elements per block
    constexpr short NF = 8;  // elements per lane within a block

    const int nb = args.ne00/NB;

    const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const T1 * y = (device const T1 *) (src1 + offset1);

    // pointers to src0 rows
    device const T0 * ax[NR0];
    FOR_UNROLL (short row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const T0 *) (src0 + offset0);
    }

    float sumf[NR0] = { 0.f };

    const short ix = tiisg/(NW/NF); // block (0 .. NF-1) within the simdgroup's group of NF blocks
    const short il = tiisg%(NW/NF); // which NF-element slice of that block this lane handles

    const int ib0 = sgitg*NF + ix;

    // src1 slice converted once per step and reused for every src0 row
    float yl[NF];

    device const T1 * yb = y + (ib0*NB + il*NF);

    for (int ib = ib0; ib < nb; ib += NSG*NF) {
        for (short i = 0; i < NF; ++i) {
            yl[i] = (float) yb[i];
        }

        for (short row = 0; row < NR0; row++) {
            device const T0 * xb = ax[row] + (ib*NB + il*NF);

            float sumq = 0.f;
            FOR_UNROLL (short i = 0; i < NF; ++i) {
                sumq += (float) xb[i] * yl[i];
            }

            sumf[row] += sumq;
        }

        // skip the NSG*NF blocks consumed by all simdgroups in this step
        yb += NSG*NF*NB;
    }

    // leftover ne00 % NB elements, strided over every lane of every simdgroup
    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
        for (short row = 0; row < NR0; row++) {
            sumf[row] += (float) ax[row][i] * (float) y[i];
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
3596
// Picks the compile-time NR0 (src0 rows per threadgroup) for kernel_mul_mv_t_t_impl from args.nr0.
// Only NR0 == 2 is instantiated; for any other args.nr0 this returns without touching dst.
// (NR0 = 1, 3 and 4 variants are currently disabled.)
template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_disp(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    if (args.nr0 == 2) {
        kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
    }
}
3614
// Kernel entry point for the generic mat-vec path. Everything is forwarded to kernel_mul_mv_t_t_disp,
// which selects the NR0 specialization from args.nr0. shmem is threadgroup buffer 0.
template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(
        args,
        src0, src1, dst,
        shmem,
        tgpig, tiisg, sgitg);
}
3627
// Kernel function type shared by every kernel_mul_mv_t_t instantiation (the signature does not depend on T0/T1).
typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;

// Host-visible instantiations, named kernel_mul_mv_<src0 type>_<src1 type>.
template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float>;
template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half>;
// bfloat variants exist only when GGML_METAL_HAS_BF16 is defined (it is undefined for __METAL_VERSION__ < 310)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
#endif
3637
// Vectorized (4-wide) variant of the generic mat-vec kernel.
//
// For row in [0, NR0) it computes
//     dst[im][r1][r0 + row] = sum_i src0[r0 + row][i] * src1[r1][i]
// with r0 = tgpig.x*NR0, r1 = tgpig.y and im = tgpig.z (src0 broadcast over the batch via args.r2 / args.r3).
//
// Work split: rows are processed in blocks of NB = 32 elements. NW/NF = 2 lanes share one block and each
// lane handles NF = 16 consecutive elements as NF4 = 4 four-wide vectors (T04 / T14). A simdgroup covers NF
// blocks per step, and all NSG simdgroups cover NSG*NF blocks. The last ne00 % NB elements are handled by a
// scalar strided tail loop using the T0 / T1 views of the same rows.
//
// All products are formed in float, both in the vector main loop and in the scalar tail. Before this change
// the tail multiplied half*half / bfloat*bfloat in the input precision.
//
// NOTE(review): rows r0 + row >= ne01 are still read; their sums go to helper_mv_reduce_and_write together
// with ne01, which presumably drops them - confirm the host dispatch keeps these reads in bounds.
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
void kernel_mul_mv_t_t_4_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NB  = 32;   // elements per block
    constexpr short NF  = 16;   // elements per lane within a block
    constexpr short NF4 = NF/4; // 4-wide vectors per lane within a block

    const int nb = args.ne00/NB;

    const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    // scalar and vector views of the same src1 row
    device const T1  * y  = (device const T1  *) (src1 + offset1);
    device const T14 * y4 = (device const T14 *) (src1 + offset1);

    // scalar and vector views of the src0 rows
    device const T0  * ax [NR0];
    device const T04 * ax4[NR0];
    FOR_UNROLL (short row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax [row] = (device const T0  *) (src0 + offset0);
        ax4[row] = (device const T04 *) (src0 + offset0);
    }

    float sumf[NR0] = { 0.f };

    const short ix = tiisg/(NW/NF); // block (0 .. NF-1) within the simdgroup's group of NF blocks
    const short il = tiisg%(NW/NF); // which NF-element half of that block this lane handles

    const int ib0 = sgitg*NF + ix;

    // src1 slice converted to float once per step and reused for every src0 row
    float4 yl4[NF4];

    device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;

    for (int ib = ib0; ib < nb; ib += NSG*NF) {
        for (short i = 0; i < NF4; ++i) {
            yl4[i] = float4(yb4[i]);
        }

        for (short row = 0; row < NR0; row++) {
            device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;

            float sumq = 0.f;
            FOR_UNROLL (short i = 0; i < NF4; ++i) {
                sumq += dot(float4(xb4[i]), yl4[i]);
            }

            sumf[row] += sumq;
        }

        // skip the NSG*NF blocks consumed by all simdgroups in this step (in units of 4-wide vectors)
        yb4 += NSG*NF*NB/4;
    }

    // leftover ne00 % NB elements, strided over every lane of every simdgroup
    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
        for (short row = 0; row < NR0; row++) {
            sumf[row] += (float) ax[row][i] * (float) y[i];
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
3720
// Picks the compile-time NR0 (src0 rows per threadgroup) for kernel_mul_mv_t_t_4_impl from args.nr0.
// Only NR0 == 2 is instantiated; any other args.nr0 returns without writing dst.
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_t_t_4_disp(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    if (args.nr0 != 2) {
        return; // NR0 = 1, 3, 4 variants are currently disabled
    }

    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
3738
// Kernel entry point for the 4-wide mat-vec path. T04 / T14 are the 4-wide vector types matching the
// scalar T0 / T1. Everything is forwarded to kernel_mul_mv_t_t_4_disp; shmem is threadgroup buffer 0.
template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_t_t_4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
        args,
        src0, src1, dst,
        shmem,
        tgpig, tiisg, sgitg);
}
3751
// Kernel function type shared by every kernel_mul_mv_t_t_4 instantiation (signature independent of the type arguments).
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;

// Host-visible 4-wide instantiations, named kernel_mul_mv_<src0 type>_<src1 type>_4.
// Template arguments: <src0 scalar, src0 vector, src1 scalar, src1 vector>.
template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4>;
template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4>;
// bfloat variants exist only when GGML_METAL_HAS_BF16 is defined (it is undefined for __METAL_VERSION__ < 310)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
#endif
3761
// Straightforward mat-vec: each thread computes one full dot product on its own.
//
// The thread with lane tiisg in threadgroup tgpig.x handles src0 row r0 = tgpig.x*32 + tiisg.
// (Presumably one 32-wide simdgroup per threadgroup - confirm against the host dispatch.)
// tgpig.y selects the src1 row r1 and tgpig.z the batch index. Lanes with r0 >= ne01 exit immediately.
// Both operands are converted to float before multiplying, and the sum is accumulated in float.
template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_short_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig,
        ushort tiisg) {
    const int r0 = tgpig.x*32 + tiisg;
    if (r0 >= args.ne01) {
        return; // past the last src0 row: nothing to compute
    }

    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t off_x = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t off_y = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const T0 * x = (device const T0 *) (src0 + off_x);
    device const T1 * y = (device const T1 *) (src1 + off_y);

    float acc = 0.0f;
    for (int k = 0; k < args.ne00; ++k) {
        acc += (float) x[k] * (float) y[k];
    }

    device float * out = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
    out[(uint64_t)r1*args.ne0 + r0] = acc;
}
3799
// Kernel entry point for the one-thread-per-output mat-vec path; forwards to kernel_mul_mv_t_t_short_impl.
template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t_short(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, tgpig, tiisg);
}
3816
// Kernel function type shared by every kernel_mul_mv_t_t_short instantiation.
typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;

// Host-visible instantiations, named kernel_mul_mv_<src0 type>_<src1 type>_short.
template [[host_name("kernel_mul_mv_f32_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
template [[host_name("kernel_mul_mv_f16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  float>;
template [[host_name("kernel_mul_mv_f16_f16_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,  half>;
// bfloat variants exist only when GGML_METAL_HAS_BF16 is defined (it is undefined for __METAL_VERSION__ < 310)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
3826
// Function constant used by kernel_rope_multi. When true, the per-dimension position stream is chosen
// with the interleaved mapping (sector % 3 selects the h / w / t stream); when false, contiguous sections
// sect_0 .. sect_3 are used.
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
3828
// YaRN ramp over the rotary pair index i0/2 (integer division):
// returns 1 for i0/2 <= low, 0 for i0/2 >= high, and falls off linearly in between.
// The denominator is clamped to at least 0.001 so that high <= low does not divide by zero.
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float span = max(0.001f, high - low);
    const float t    = (i0 / 2 - low) / span;

    return 1.0f - min(1.0f, max(0.0f, t));
}
3833
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Writes mscale-scaled cos/sin of the rotation angle for pair index i0.
// With ext_factor == 0 the angle is simply freq_scale*theta_extrap (pure interpolation). Otherwise the
// angle is blended toward theta_extrap by rope_yarn_ramp(corr_dims)*ext_factor, and the magnitude scale
// is boosted by 1 + 0.1*ln(1/freq_scale).
static void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
    thread float * cos_theta, thread float * sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;

    float theta = theta_interp;
    float scale = mscale;

    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;

        theta  = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
        scale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    *cos_theta = cos(theta) * scale;
    *sin_theta = sin(theta) * scale;
}
3852
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
// This function evaluates corr_fac with max_pos_emb = n_ctx_orig.
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float ratio = n_ctx_orig / (n_rot * 2 * M_PI_F);

    return n_dims * log(ratio) / (2 * log(base));
}
3858
// YaRN start/end correction dims: dims[0] comes from beta_fast (rounded down, clamped to >= 0) and
// dims[1] from beta_slow (rounded up, clamped to <= n_dims - 1).
static void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    const float start = rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base);
    const float end   = rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base);

    dims[0] = max(0.0f,          floor(start));
    dims[1] = min(n_dims - 1.0f, ceil(end));
}
3866
// RoPE in the "normal" layout: rotates adjacent pairs (x[i0], x[i0 + 1]) of each row.
//
// One threadgroup per row (tgpig = (i1, i2, i3)); threads step over pairs i0 = 2*tiitg, 2*tiitg + 2*tptg.x, ...
// The angle for pair i0 is pos[i2] * freq_base^(-i0/n_dims), divided by src2[i0/2] when args.src2 is set,
// then YaRN-corrected by rope_yarn. Elements with i0 >= n_dims are copied unchanged.
template<typename T>
kernel void kernel_rope_norm(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * positions    = (device const int32_t *) src1;
    device const float   * freq_factors = (device const float   *) src2;

    const float theta_base = (float) positions[i2];
    const float inv_ndims  = -1.f/args.n_dims;

    // start of the current row in src0 / dst (byte addresses)
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        device const T * src      = (device const T *) (src_row + i0*args.nb00);
        device       T * dst_data = (device       T *) (dst_row + i0*args.nb0);

        if (i0 >= args.n_dims) {
            // dimensions beyond n_dims are not rotated
            dst_data[0] = src[0];
            dst_data[1] = src[1];
            continue;
        }

        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? freq_factors[i0/2] : 1.0f;

        float cos_theta;
        float sin_theta;
        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        const float x0 = src[0];
        const float x1 = src[1];

        dst_data[0] = x0*cos_theta - x1*sin_theta;
        dst_data[1] = x0*sin_theta + x1*cos_theta;
    }
}
3919
// RoPE in the "NeoX" layout: for pair index ic = i0/2, rotates (x[ic], x[ic + n_dims/2]).
//
// One threadgroup per row (tgpig = (i1, i2, i3)); threads step over i0 = 2*tiitg, 2*tiitg + 2*tptg.x, ...
// The angle is pos[i2] * freq_base^(-i0/n_dims), divided by src2[ic] when args.src2 is set, then
// YaRN-corrected by rope_yarn. For i0 >= n_dims the pair (x[i0], x[i0 + 1]) is copied unchanged.
template<typename T>
kernel void kernel_rope_neox(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * positions    = (device const int32_t *) src1;
    device const float   * freq_factors = (device const float   *) src2;

    const float theta_base = (float) positions[i2];
    const float inv_ndims  = -1.f/args.n_dims;

    // start of the current row in src0 / dst (byte addresses)
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // dimensions beyond n_dims are not rotated
            device const T * src      = (device const T *) (src_row + i0*args.nb00);
            device       T * dst_data = (device       T *) (dst_row + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
            continue;
        }

        const int ic     = i0/2;
        const int n_half = args.n_dims/2; // distance between the two elements of a rotated pair

        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? freq_factors[ic] : 1.0f;

        float cos_theta;
        float sin_theta;
        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * src      = (device const T *) (src_row + ic*args.nb00);
        device       T * dst_data = (device       T *) (dst_row + ic*args.nb0);

        const float x0 = src[0];
        const float x1 = src[n_half];

        dst_data[0]      = x0*cos_theta - x1*sin_theta;
        dst_data[n_half] = x0*sin_theta + x1*cos_theta;
    }
}
3972
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    // Multi-section RoPE with NeoX-style pairing.
    //
    // Threadgroup (i1, i2, i3) = (tgpig[0], tgpig[1], tgpig[2]) handles one row. Its threads walk
    // the row two elements at a time. Pairs inside the first n_dims elements are rotated as
    // (ic, ic + n_dims/2); everything past n_dims is copied through unchanged.
    //
    // src1: int32 positions, read as pos[i2 + ne02*axis] with axis in [0, 4).
    // src2: optional per-pair float frequency factors (used only when args.src2 != 0).
    const int i1 = tgpig[0];
    const int i2 = tgpig[1];
    const int i3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos  = (device const int32_t *) src1;
    device const float   * ffac = (device const float   *) src2;

    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    const float inv_ndims = -1.f/args.n_dims;
    const int   half_dims = args.n_dims/2;

    // cumulative section ends; these do not change inside the loop
    const int end_0   = args.sect_0;
    const int end_01  = end_0  + args.sect_1;
    const int end_012 = end_01 + args.sect_2;
    const int period  = end_012 + args.sect_3;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // outside the rotated span: copy the two elements through
            device const T * s = (device const T *)(src_row + i0*args.nb00);
            device       T * d = (device       T *)(dst_row + i0*args.nb0);

            d[0] = s[0];
            d[1] = s[1];
            continue;
        }

        const int ic     = i0/2;
        const int sector = ic % period;

        // choose which block of positions (t = 0, h = 1, w = 2, e = 3) drives this pair
        int axis;
        if (FC_rope_is_imrope) {
            // interleaved: sector % 3 picks h/w/t while that section still has entries left,
            // otherwise the fourth (e) block is used
            const int phase = sector % 3;
            if      (phase == 1 && sector < 3*args.sect_1) { axis = 1; }
            else if (phase == 2 && sector < 3*args.sect_2) { axis = 2; }
            else if (phase == 0 && sector < 3*args.sect_0) { axis = 0; }
            else                                           { axis = 3; }
        } else {
            // contiguous: [0, end_0) -> t, [end_0, end_01) -> h, [end_01, end_012) -> w, rest -> e
            axis = sector < end_0 ? 0 : sector < end_01 ? 1 : sector < end_012 ? 2 : 3;
        }

        const float theta_base  = (float) pos[i2 + args.ne02*axis];
        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = args.src2 ? ffac[ic] : 1.0f;

        float cos_theta;
        float sin_theta;
        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * s = (device const T *)(src_row + ic*args.nb00);
        device       T * d = (device       T *)(dst_row + ic*args.nb0);

        const float x0 = s[0];
        const float x1 = s[half_dims];

        d[0]         = x0*cos_theta - x1*sin_theta;
        d[half_dims] = x0*sin_theta + x1*cos_theta;
    }
}
4055
template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    // Two-section (vision) multi-section RoPE.
    //
    // Threadgroup (i1, i2, i3) = (tgpig[0], tgpig[1], tgpig[2]) handles one row; threads walk it
    // two elements at a time. The rotated span covers 2*n_dims elements: element ic is paired
    // with ic + n_dims (unlike kernel_rope_multi, which pairs ic with ic + n_dims/2). Elements
    // past 2*n_dims are copied through unchanged.
    //
    // Sectors [0, sect_0) use pos[i2] with exponent index p = sector.
    // Sectors [sect_0, sect_0 + sect_1) use pos[i2 + ne02] with p = sector - sect_0.
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    // mrope sections (only 2 dimensions supported); does not change inside the loop
    const int sect_dims = args.sect_0 + args.sect_1;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            const int sector = ic % sect_dims;

            float p;
            float theta_base;
            // The split point must be sect_0: the second branch measures p from sect_0.
            // The previous test against sect_1 produced a negative p (and the wrong
            // position row) for sectors in [sect_1, sect_0) whenever sect_1 < sect_0.
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            // outside the rotated span: copy the two elements through
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
4122
// Function types for the RoPE kernels, taken from their <float> specializations. The parameter
// lists of kernel_rope_multi / kernel_rope_vision pass every buffer as `device char *`, so they
// do not depend on T and the same type is reused for the <half> instantiations.
// NOTE(review): presumably the same holds for kernel_rope_norm / kernel_rope_neox; their
// signatures are not visible here.
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;

// Host-visible entry points: the _f32/_f16 suffix selects the element type T used for src0/dst.
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;

template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;

template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
4139
// Function type of kernel_im2col<T>. T only changes how dst is written (x is always float),
// so this single type serves both the f32 and f16 instantiations.
typedef void (im2col_t)(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]);
4148
template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // im2col: unfold input patches into rows of a matrix.
    //
    // Grid        = (IC, OH, OW) threadgroups -> (iic, ioh, iow) = tgpig
    // Threadgroup = (batch lanes, KH, KW)     -> (first batch, ikh, ikw) = tpitg
    //
    // Each thread writes one column entry (iic, ikh, ikw) of the output row (n, ioh, iow) for
    // n = tpitg[0], tpitg[0] + ntg[0], ... < N. Output rows have args.CHW entries. Input
    // pixels outside [0, IH) x [0, IW) are written as 0.
    const int64_t OH = tgpg[1];
    const int64_t OW = tgpg[2];

    const int64_t KH = ntg[1];
    const int64_t KW = ntg[2];

    const int64_t iic = tgpig[0];
    const int64_t ioh = tgpig[1];
    const int64_t iow = tgpig[2];

    const int64_t ikh = tpitg[1];
    const int64_t ikw = tpitg[2];

    // input coordinate sampled by this (output pixel, kernel tap) pair
    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;

    const bool inside = iih >= 0 && iih < args.IH && iiw >= 0 && iiw < args.IW;

    // per-batch-step advances of the destination / source element offsets
    const int64_t dst_step = ntg[0]*args.CHW*OH*OW;
    const int64_t src_step = ntg[0]*args.ofs0;

    const int64_t n0 = tpitg[0];

    int64_t offset_dst = (n0*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
    int64_t offset_src = n0*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;

    device T * pdst = (device T *) (dst);

    for (int64_t n = n0; n < args.N; n += ntg[0]) {
        // x is only dereferenced when the coordinate is inside the input
        pdst[offset_dst] = inside ? x[offset_src] : 0.0f;

        offset_dst += dst_step;
        offset_src += src_step;
    }
}
4200
// Host-visible im2col entry points; the suffix names the destination element type T.
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4203
// TODO: obsolete -- remove
4205//typedef void (im2col_ext_t)(
4206//        constant ggml_metal_kargs_im2col & args,
4207//        device const float * x,
4208//        device        char * dst,
4209//        uint3 tgpig[[threadgroup_position_in_grid]],
4210//        uint3  tgpg[[threadgroups_per_grid]],
4211//        uint3 tpitg[[thread_position_in_threadgroup]],
4212//        uint3   ntg[[threads_per_threadgroup]]);
4213//
4214//template <typename T>
4215//kernel void kernel_im2col_ext(
4216//        constant ggml_metal_kargs_im2col & args,
4217//        device const float * x,
4218//        device        char * dst,
4219//        uint3 tgpig[[threadgroup_position_in_grid]],
4220//        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4221//        uint3 tpitg[[thread_position_in_threadgroup]],
4222//        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
4223//    const int64_t KHW = (int64_t)args.KHW;
4224//
4225//    const int64_t d   = tgpig[0] / args.CHW;
4226//    const int64_t chw = tgpig[0] % args.CHW;
4227//    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
4228//    const int64_t HW = tgpig[0] % KHW;
4229//
4230//    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4231//    if (tpitg_0 >= args.N) {
4232//        return;
4233//    }
4234//
4235//    const int64_t tpitg_1 = HW / args.KW;
4236//    const int64_t tpitg_2 = HW % args.KW;
4237//
4238//    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4239//    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4240//
4241//    const int64_t offset_dst =
4242//        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4243//        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4244//
4245//    device T * pdst = (device T *) (dst);
4246//
4247//    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4248//        pdst[offset_dst] = 0.0f;
4249//    } else {
4250//        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4251//        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4252//    }
4253//}
4254//
4255//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4256//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4257
// Clip the taps of one kernel axis (ksize taps, dilation dil) to those whose input coordinate
// base + k*dil lies inside [0, isize). Returns the half-open range [r.x, r.y).
// An empty range (r.x >= r.y) means no tap of this axis touches the input.
static inline int2 conv_2d_tap_range(int32_t base, int32_t dil, int32_t ksize, int32_t isize) {
    // first tap that is not left of the input: ceil(-base/dil) when base < 0
    const int32_t k_first = base < 0 ? (-base + dil - 1)/dil : 0;

    // largest offset from `base` that still lands inside the input
    const int32_t room = isize - 1 - base;

    int32_t k_last = ksize; // exclusive
    if (room < 0) {
        k_last = k_first;
    } else if (base + (ksize - 1)*dil >= isize) {
        k_last = min(k_last, room/dil + 1);
    }

    return int2(k_first, k_last);
}

template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {
    // Direct 2D convolution (no im2col).
    //
    // Every launched thread computes whole output elements. It visits flat output indices
    // first, first + stride, ... where stride is the total number of launched threads. The
    // flat index order (fastest first) is ow, oh, oc, n.
    //
    // Weights of type TK are widened to float; src and dst are float. All addressing uses
    // byte strides: weights nb00..nb03 = (kx, ky, ic, oc), src nb10..nb13 = (x, y, ic, n),
    // dst nb0..nb3 = (ow, oh, oc, n). Taps that fall outside the input are skipped, which is
    // equivalent to zero padding.
    const uint     tg_size   = ntg.x * ntg.y * ntg.z;
    const uint     tg_linear = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint     local_id  = (tpitg.z * ntg.y + tpitg.y) * ntg.x + tpitg.x;
    const uint     first     = tg_linear * tg_size + local_id;
    const uint64_t stride    = (uint64_t) tg_size * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t n_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;

    for (uint64_t idx = first; idx < n_outputs; idx += stride) {
        const uint64_t row   = idx / args.OW; // flattened (n, oc, oh)
        const uint64_t plane = row / args.OH; // flattened (n, oc)

        const int32_t ow = idx   % args.OW;
        const int32_t oh = row   % args.OH;
        const int32_t oc = plane % args.OC;
        const int32_t n  = plane / args.OC;

        // input coordinate hit by tap (0, 0)
        const int32_t x0 = ow*args.s0 - args.p0;
        const int32_t y0 = oh*args.s1 - args.p1;

        const int2 kx_range = conv_2d_tap_range(x0, args.d0, args.KW, args.IW);
        const int2 ky_range = conv_2d_tap_range(y0, args.d1, args.KH, args.IH);

        float acc = 0.0f;

        const bool any_tap = ky_range.x < ky_range.y && kx_range.x < kx_range.y;
        if (any_tap) {
            const uint64_t src_n = (uint64_t) n  * args.nb13;
            const uint64_t w_oc  = (uint64_t) oc * args.nb03;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                const uint64_t src_c = src_n + (uint64_t) ic * args.nb12;
                const uint64_t w_c   = w_oc  + (uint64_t) ic * args.nb02;

                for (int32_t ky = ky_range.x; ky < ky_range.y; ++ky) {
                    const uint64_t src_y = src_c + (uint64_t) (y0 + ky*args.d1) * args.nb11;
                    const uint64_t w_y   = w_c   + (uint64_t) ky * args.nb01;

                    for (int32_t kx = kx_range.x; kx < kx_range.y; ++kx) {
                        const uint64_t src_offs = src_y + (uint64_t) (x0 + kx*args.d0) * args.nb10;
                        const uint64_t w_offs   = w_y   + (uint64_t) kx * args.nb00;

                        const float xv = *(device const float *)(src + src_offs);
                        const float wv = (float) (*(device const TK *)(weights + w_offs));

                        acc += xv * wv;
                    }
                }
            }
        }

        const uint64_t dst_offs =
            (uint64_t) n  * args.nb3 +
            (uint64_t) oc * args.nb2 +
            (uint64_t) oh * args.nb1 +
            (uint64_t) ow * args.nb0;

        *(device float *)(dst + dst_offs) = acc;
    }
}
4349
// Host-visible conv_2d entry points, named kernel_conv_2d_<weight type>_<src type>.
// The weight type is the template parameter TK; src/dst are always float.
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
4371
// Function type for the 1D transposed convolution kernel.
// NOTE(review): src0 is declared as float here, so this type only matches the <float>
// specialization of kernel_conv_transpose_1d (whose src0 is `device const T *`). The
// instantiations spell out their signatures explicitly instead of using it, and nothing in
// this part of the file references it. Confirm whether it is still used elsewhere.
typedef void (conv_transpose_1d_t)(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
4379
template <typename T>
kernel void kernel_conv_transpose_1d(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const     T * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tgpg[[threadgroups_per_grid]]) {
    // 1D transposed convolution. One threadgroup per output element:
    //   tgpig[0] = output position, tgpig[1] = output channel (tgpg[1] output channels).
    //
    // Input sample i of channel c contributes to output position out_pos through kernel tap
    // out_pos - i*s0, when that tap lies in [0, K).
    // Kernel layout (elements): c*(tgpg[1]*K) + out_ch*K + tap; input layout: c*IL + i.
    const uint out_pos = tgpig[0];
    const uint out_ch  = tgpig[1];

    float acc = 0.0f;

    for (int64_t c = 0; c < args.IC; c++) {
        const int32_t w_base  = c * tgpg[1] * args.K + args.K * out_ch;
        const int32_t in_base = c * args.IL;

        for (int64_t i = 0; i < args.IL; i++) {
            const int64_t tap = (int64_t) out_pos - i * args.s0;

            if (tap >= 0 && tap < args.K) {
                acc += src0[w_base + tap] * src1[in_base + i];
            }
        }
    }

    *(device float *) (dst + out_pos * args.nb0 + out_ch * args.nb1) = acc;
}
4406
// Host-visible 1D transposed-convolution entry points, named <kernel type T>_<input type>.
// The signatures are written out because src0's type depends on T.
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
    constant ggml_metal_kargs_conv_transpose_1d & args,
    device const float * src0,
    device const float * src1,
    device        char * dst,
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3    tgpg[[threadgroups_per_grid]]);

template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
    constant ggml_metal_kargs_conv_transpose_1d & args,
    device const half  * src0,
    device const float * src1,
    device        char * dst,
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3    tgpg[[threadgroups_per_grid]]);
4424
4425
// Function type intended for the 2D transposed convolution kernel.
// NOTE(review): this does not match kernel_conv_transpose_2d's parameter list. The kernel
// also takes `threadgroup float * shared_sum`, plus tpitg/ntg instead of tgpg. It is not
// referenced in this part of the file; confirm whether it is stale before relying on it.
typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
4433
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {
    // 2D transposed convolution; the single stride args.s0 is used for both axes.
    //
    // Threadgroup (out_x, out_y, out_c) = tgpig computes one output value. Thread
    // (kw, kh) = (tpitg.x, tpitg.y) handles one kernel tap and sums its contribution over all
    // input channels. Thread 0 then adds the per-tap partials serially through shared_sum,
    // which must hold ntg.x*ntg.y floats.
    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    // The input pixel reaching (out_x, out_y) through tap (kw, kh) does not depend on the
    // input channel, so decide once whether it exists.
    const int64_t dy = out_y - kh;
    const int64_t dx = out_x - kw;

    const bool    on_grid = dy >= 0 && dx >= 0 && dy % args.s0 == 0 && dx % args.s0 == 0;
    const int64_t in_y    = dy / args.s0;
    const int64_t in_x    = dx / args.s0;

    float partial = 0.0f;

    if (on_grid && in_y < args.IH && in_x < args.IW) {
        for (int64_t in_c = 0; in_c < args.IC; in_c++) {
            const int64_t input_idx  = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
            const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

            partial += (float)src0[kernel_idx] * src1[input_idx];
        }
    }

    const uint local_id = tpitg.y * ntg.x + tpitg.x;
    shared_sum[local_id] = partial;

    // every thread reaches this barrier: no early exit above
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (local_id != 0) {
        return;
    }

    float total = 0.0f;
    const uint n_partials = ntg.x * ntg.y;
    for (uint t = 0; t < n_partials; t++) {
        total += shared_sum[t];
    }

    *(device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2) = total;
}
4493
// Host-visible 2D transposed-convolution entry points, named <kernel type T>_<input type>.
// The signatures are written out because src0's type depends on T.
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
    constant ggml_metal_kargs_conv_transpose_2d & args,
    device const float * src0,
    device const float * src1,
    device        char * dst,
    threadgroup float * shared_sum [[threadgroup(0)]],
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3   tpitg[[thread_position_in_threadgroup]],
    uint3     ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
    constant ggml_metal_kargs_conv_transpose_2d & args,
    device const half  * src0,
    device const float * src1,
    device        char * dst,
    threadgroup float * shared_sum [[threadgroup(0)]],
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3   tpitg[[thread_position_in_threadgroup]],
    uint3     ntg[[threads_per_threadgroup]]);
4515
kernel void kernel_upscale_f32(
    constant ggml_metal_kargs_upscale & args,
    device  const char * src0,
    device        char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {
    // Index-mapping upscale of a float tensor.
    //
    // Threadgroup (i1, i2, i3) = tgpig.xyz fills one destination row of ne0 elements. Each
    // destination coordinate is divided by its scale factor (sf0..sf3) and converted to int64
    // to select the source element.
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    device const char * src_row = src0
        + (int64_t) (i3/args.sf3)*args.nb03
        + (int64_t) (i2/args.sf2)*args.nb02
        + (int64_t) (i1/args.sf1)*args.nb01;

    device char * dst_row = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int64_t i00 = i0/args.sf0;

        *(device float *) (dst_row + i0*args.nb0) = *(device const float *) (src_row + i00*args.nb00);
    }
}
4541
kernel void kernel_pad_f32(
    constant ggml_metal_kargs_pad & args,
    device  const char * src0,
    device        char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {
    // Zero padding appended at the end of each dimension.
    //
    // Threadgroup (i1, i2, i3) = tgpig.xyz writes one destination row of ne0 floats.
    // - Rows present in src0 (i1 < ne01, i2 < ne02, i3 < ne03) copy their first ne00 values
    //   and zero the remainder.
    // - All other rows are zeroed entirely.
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    const bool has_src_row = i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03;

    // source and destination use the same indices: nothing is prepended
    device const float * src_row = (device const float *) (src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01);
    device       float * dst_row = (device       float *) (dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        // src_row is only read when the element exists in src0
        dst_row[i0] = (has_src_row && i0 < args.ne00) ? src_row[i0] : 0.0f;
    }
}
4577
kernel void kernel_pad_reflect_1d_f32(
    constant   ggml_metal_kargs_pad_reflect_1d & args,
    device  const char * src0,
    device        char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3  tgpg[[threadgroups_per_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {
    // Reflect padding along dim 0.
    //
    // Threadgroup (i1, i2, i3) = tgpig.xyz fills one destination row of ne0 floats from the
    // matching src0 row, laid out as
    //     [p0 mirrored values | source row | p1 mirrored values]
    // The mirror excludes the edge element itself. Rows outside src0 (i1 >= ne01, i2 >= ne02
    // or i3 >= ne03) are left untouched.
    const int64_t i1 = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    if (i1 >= args.ne01 || i2 >= args.ne02 || i3 >= args.ne03) {
        return;
    }

    device const float * src_row = (device const float *) (src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01);
    device       float * dst_row = (device       float *) (dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1);

    const int64_t right_start = args.ne0 - args.p1; // first destination index of the right padding

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        int64_t j;
        if (i0 < args.p0) {
            j = args.p0 - i0;                                     // left mirror
        } else if (i0 < right_start) {
            j = i0 - args.p0;                                     // interior copy
        } else {
            j = (right_start - args.p0 - 2) - (i0 - right_start); // right mirror
        }

        dst_row[i0] = src_row[j];
    }
}
4610
kernel void kernel_arange_f32(
    constant   ggml_metal_kargs_arange & args,
    device        char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {
    // Fill dst (float) with start + step*k for k in [0, ne0).
    // Only the position inside the threadgroup is used, so every threadgroup would write the
    // full range. Presumably this is dispatched as a single threadgroup; confirm on the host side.
    device float * out = (device float *) dst;

    for (int k = tpitg.x; k < args.ne0; k += ntg.x) {
        out[k] = args.start + args.step * k;
    }
}
4624
kernel void kernel_timestep_embedding_f32(
    constant  ggml_metal_kargs_timestep_embedding & args,
    device  const char * src0,
    device        char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3   ntg[[threads_per_threadgroup]]) {
    // Sinusoidal timestep embedding. Threadgroup tgpig.x = row handles timestep t = src0[row]
    // and writes output row `row` (byte stride nb1). With h = dim/2:
    //   out[j]     = cos(t * f_j)
    //   out[j + h] = sin(t * f_j)
    //   f_j        = exp(-ln(max_period) * j / h),   for j in [0, h)
    // When dim is odd, the last entry out[2*h] is set to 0.
    const int row = tgpig.x;

    device       float * out = (device float *)(dst + row*args.nb1);
    device const float * ts  = (device const float *) src0;

    const int   half_          = args.dim / 2;
    const float t              = ts[row];
    const float log_max_period = log((float)args.max_period);

    for (int j = tpitg.x; j < half_; j += ntg.x) {
        const float freq  = (float)exp(-log_max_period * j / half_);
        const float angle = t * freq;

        out[j        ] = cos(angle);
        out[j + half_] = sin(angle);
    }

    const bool odd_dim = args.dim % 2 != 0;
    if (odd_dim && tpitg.x == 0) {
        out[2 * half_] = 0.f;
    }
}
4649
// bitonic sort implementation following the CUDA kernels as reference
//
// Function type shared by the ascending/descending instantiations of kernel_argsort_f32_i32.
// shmem_i32 is threadgroup scratch with one int32 index slot per thread (ntg.x slots).
typedef void (argsort_t)(
        constant   ggml_metal_kargs_argsort & args,
        device   const char * src0,
        device      int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]);
4659
// Block-wise argsort of float rows (bitonic network in threadgroup memory).
//
// tgpig[0] encodes both the row i01 (tgpig[0] % ne01) and the block ib (tgpig[0] / ne01).
// Block ib covers source columns [ib*ntg.x, ib*ntg.x + ntg.x). tgpig[1] and tgpig[2] select
// i02 and i03. Each thread owns one slot (col = tpitg[0]) of shmem_i32.
//
// Slots holding an index >= ne00 (past the end of the row) always sort after every real
// element. The first top_k sorted indices of the block are written to dst at offset
// ib*top_k of a row of length ne0.
// NOTE(review): the network assumes ntg.x is a power of two; confirm on the host side.
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
        constant   ggml_metal_kargs_argsort & args,
        device   const char * src0,
        device      int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // bitonic sort
    const int col = tpitg[0];
    const int ib  = tgpig[0] / args.ne01; // block index along the row

    const int i00 = ib*ntg.x;             // first source column of this block
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);

    // initialize indices
    shmem_i32[col] = i00 + col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // k: size of the bitonic sequences being merged; j: compare distance within a merge step.
    // Both loop bounds are uniform across the threadgroup, so every thread reaches each barrier.
    for (int k = 2; k <= ntg.x; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            // only the lower slot of each (col, ixj) pair performs the compare-and-swap
            if (ixj > col) {
                if ((col & k) == 0) {
                    // this sub-sequence is sorted in `order`; out-of-row indices move to ixj (after)
                    if (shmem_i32[col] >= args.ne00 ||
                       (shmem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                } else {
                    // this sub-sequence is sorted in reverse; out-of-row indices move to col
                    if (shmem_i32[ixj] >= args.ne00 ||
                       (shmem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                }
            }

            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // output offset of this block inside the dst row: each block emits top_k indices
    const int64_t i0 = ib*args.top_k;

    // copy the result to dst without the padding
    if (i0 + col < args.ne0 && col < args.top_k) {
        // dst is addressed as a contiguous int32 tensor with dims (ne0, ne1, ne2, ...)
        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;

        dst[col] = shmem_i32[col];
    }
}
4721
// Host-visible argsort entry points; the suffix selects the sort order template parameter.
template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
4724
// Function type matching kernel_argsort_merge_f32_i32.
// Parameters:
//   src0 - float rows holding the values being sorted (addressed via byte strides)
//   tmp  - int32 index runs from a previous pass, merged pairwise
//   dst  - int32 output indices
typedef void (argsort_merge_t)(
        constant   ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]);
4733
// Merge pass of the argsort: each threadgroup merges two adjacent runs of `args.len`
// indices taken from `tmp` (each run already ordered by the src0 values its indices point
// to) and writes the merged indices to `dst`, keeping only the first `args.top_k` of them.
//
// The output range is split into contiguous chunks, one per thread. Each thread locates
// the start of its chunk in the two runs with a binary search (merge path) and then does
// a sequential two-way merge over its chunk.
//
// grid: tgpig[0] = im*ne01 + i01 (im = which pair of runs), tgpig[1] = i02, tgpig[2] = i03
// tmp rows are args.ne0 indices wide, dst rows are args.top_k indices wide.
// Ties take the element from the left run first, so the merge is stable.
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
        constant   ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {

    // decode which pair of runs (im) and which row (i01, i02, i03) this threadgroup handles
    const int im  = tgpig[0] / args.ne01;
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    // first element of the left run; the right run starts args.len elements later
    const int start = im * (2 * args.len);

    // actual run lengths, clipped at the end of the row (a trailing run may be short or empty)
    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));

    const int total = len0 + len1;

    // left run of this row in tmp (row stride ne0)
    device const int32_t * tmp0 = tmp + start
        + i01*args.ne0
        + i02*args.ne0*args.ne01
        + i03*args.ne0*args.ne01*args.ne02;

    // right run, immediately after the left one
    device const int32_t * tmp1 = tmp0 + args.len;

    // output position of this pair of runs (row stride top_k)
    // NOTE(review): dst rows are addressed with a stride of top_k, yet the output is placed at
    //               offset `start` within the row. For im > 0 this stays inside the row only if
    //               start + top_k fits in it. Presumably the host picks top_k accordingly for
    //               intermediate passes - confirm against the dispatch code.
    dst += start
        + i01*args.top_k
        + i02*args.top_k*args.ne01
        + i03*args.top_k*args.ne01*args.ne02;

    // the values being sorted; the runs hold indices into this row
    device const float * src0_row = (device const float *)(src0
        + args.nb01*i01
        + args.nb02*i02
        + args.nb03*i03);

    // both runs are empty (pair lies entirely past the end of the row)
    if (total == 0) {
        return;
    }

    // each thread produces merged outputs [k0, k1), truncated to the first top_k
    const int chunk = (total + ntg.x - 1) / ntg.x;

    const int k0 = tpitg.x * chunk;
    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);

    // this thread's chunk starts past the kept outputs
    if (k0 >= args.top_k) {
        return;
    }

    // this thread's chunk starts past the merged length
    if (k0 >= total) {
        return;
    }

    // search range for i = number of left-run elements among the first k0 outputs
    int low  = k0 > len1 ? k0 - len1 : 0;
    int high = MIN(k0, len0);

    // binary-search partition (i, j) such that i + j = k
    // (i elements of the left run and j of the right run make up the first k0 outputs)
    while (low < high) {
        const int mid = (low + high) >> 1;

        const int32_t idx0 = tmp0[mid];
        const int32_t idx1 = tmp1[k0 - mid - 1];

        const float val0 = src0_row[idx0];
        const float val1 = src0_row[idx1];

        // does left element `mid` come before right element `k0 - mid - 1`? (ties go left)
        bool take_left;
        if (order == GGML_SORT_ORDER_ASC) {
            take_left = (val0 <= val1);
        } else {
            take_left = (val0 >= val1);
        }

        // if so, at least mid + 1 left elements precede position k0
        if (take_left) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    // starting positions of this thread's chunk in the left and right runs
    int i = low;
    int j = k0 - i;

    // keep the merge fronts into registers
    // (current index and value of each run; loaded only while the run is not exhausted)
    int32_t idx0 = 0;
    float   val0 = 0.0f;
    if (i < len0) {
        idx0 = tmp0[i];
        val0 = src0_row[idx0];
    }

    int32_t idx1 = 0;
    float   val1 = 0.0f;
    if (j < len1) {
        idx1 = tmp1[j];
        val1 = src0_row[idx1];
    }

    // sequential two-way merge of the chunk
    for (int k = k0; k < k1; ++k) {
        int32_t out_idx;

        if (i >= len0) {
            // left run exhausted: copy the rest of the chunk from the right run
            while (k < k1) {
                dst[k++] = tmp1[j++];
            }
            break;
        } else if (j >= len1) {
            // right run exhausted: copy the rest of the chunk from the left run
            while (k < k1) {
                dst[k++] = tmp0[i++];
            }
            break;
        } else {
            bool take_left;

            // same comparison as in the partition search, so ties again go left
            if (order == GGML_SORT_ORDER_ASC) {
                take_left = (val0 <= val1);
            } else {
                take_left = (val0 >= val1);
            }

            // emit the chosen front and advance that run, reloading its front if any remain
            if (take_left) {
                out_idx = idx0;
                ++i;
                if (i < len0) {
                    idx0 = tmp0[i];
                    val0 = src0_row[idx0];
                }
            } else {
                out_idx = idx1;
                ++j;
                if (j < len1) {
                    idx1 = tmp1[j];
                    val1 = src0_row[idx1];
                }
            }
        }

        dst[k] = out_idx;
    }
}
4877
// explicit instantiations of the merge kernel for both sort orders;
// [[host_name]] sets the function name the host uses to look up each variant
template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
4880
// function constants of kernel_flash_attn_ext_pad (fixed when the pipeline is specialized)
// has_mask: also copy the mask columns of the trailing KV chunk into the pad buffer
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];

// ncpsg: chunk size C (KV cache items per chunk); presumably must equal the C of the
//        attention kernel that consumes the pad buffer - confirm on the host side
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
4884
// Copy the trailing partial chunk of the KV cache (its last ne11 % C rows) into an extra
// pad buffer holding exactly C rows per K/V head. Rows past the real data are zero-filled.
// When a mask is present, the matching C mask columns are copied too, and the columns past
// the real data are set to -MAXHALF.
//
// dst layout, back to back:
//   K pad    : C rows of nb11 bytes for each (i2 < ne_12_2, i3 < ne_12_3)
//   V pad    : C rows of nb21 bytes for each (i2 < ne_12_2, i3 < ne_12_3)
//   mask pad : C halfs for each (mask row < ne31, i2 < ne32, i3 < ne33)
//
// tgpig = (row within the chunk, dim 2, dim 3); the threads of a threadgroup stride over
// the bytes/columns of that row. The `+= C` mask-row loop suggests the grid's x extent is C
// (assumption - confirm on the host side).
kernel void kernel_flash_attn_ext_pad(
        constant ggml_metal_kargs_flash_attn_ext_pad & args,
        device const char * k,
        device const char * v,
        device const char * mask,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t C = FC_flash_attn_ext_pad_ncpsg;

    // threadgroup coordinates
    const int32_t row = tgpig[0];
    const int32_t h2  = tgpig[1];
    const int32_t h3  = tgpig[2];

    // number of real rows in the trailing chunk, and the first of them
    const int32_t n_tail     = args.ne11 % C;
    const int32_t tail_start = args.ne11 - n_tail;

    // sub-buffers of dst
    device char * const k_pad    = dst;
    device char * const v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
    device char * const mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;

    const bool has_kv_head = (h2 < args.ne_12_2) && (h3 < args.ne_12_3);

    if (has_kv_head) {
        device char * const k_out = k_pad + args.nb11*row + args.nb11*C*h2 + args.nb11*C*args.ne_12_2*h3;
        device char * const v_out = v_pad + args.nb21*row + args.nb21*C*h2 + args.nb21*C*args.ne_12_2*h3;

        if (row < n_tail) {
            // real row: byte-wise copy from the cache
            device const char * const k_in = k + args.nb11*(tail_start + row) + args.nb12*h2 + args.nb13*h3;
            device const char * const v_in = v + args.nb21*(tail_start + row) + args.nb22*h2 + args.nb23*h3;

            for (uint64_t b = tiitg; b < args.nb11; b += ntg.x) {
                k_out[b] = k_in[b];
            }
            for (uint64_t b = tiitg; b < args.nb21; b += ntg.x) {
                v_out[b] = v_in[b];
            }
        } else {
            // padding row: the value is irrelevant because the attention scores of these
            // rows are masked out, so plain zeros are written
            for (uint64_t b = tiitg; b < args.nb11; b += ntg.x) {
                k_out[b] = 0;
            }
            for (uint64_t b = tiitg; b < args.nb21; b += ntg.x) {
                v_out[b] = 0;
            }
        }
    }

    if (!FC_flash_attn_ext_pad_has_mask) {
        return;
    }

    if (h2 >= args.ne32 || h3 >= args.ne33) {
        return;
    }

    // each threadgroup handles mask rows row, row + C, row + 2*C, ...
    for (int mrow = row; mrow < args.ne31; mrow += C) {
        device const half * const m_in  = (device const half *)(mask + args.nb31*mrow + args.nb32*h2 + args.nb33*h3) + tail_start;
        device       half * const m_out = (device       half *)(mask_pad) + C*mrow + C*args.ne31*h2 + C*args.ne31*args.ne32*h3;

        for (int c = tiitg; c < C; c += ntg.x) {
            m_out[c] = c < n_tail ? m_in[c] : -MAXHALF;
        }
    }
}
4950
// function constants of kernel_flash_attn_ext_blk: the mask is classified in blocks of
// Q (nqptg) query rows by C (ncpsg) KV columns
constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
4953
// scan the blocks of the mask that are not masked
// 0 -     masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
// 2 - all zero
//
// One simdgroup per C x Q block; tgpig = (KV block i0, query block i1, i2 + ne32*i3).
// One char per block is written to dst, ordered as (i3, i2, query block, KV block).
// Blocks that reach past the end of a mask row (i0*C + C > ne30) are reported as 1
// without being inspected.
kernel void kernel_flash_attn_ext_blk(
        constant ggml_metal_kargs_flash_attn_ext_blk & args,
        device const char * mask,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    // block size C x Q
    const int32_t Q = FC_flash_attn_ext_blk_nqptg;
    const int32_t C = FC_flash_attn_ext_blk_ncpsg;

    constexpr short NW  = N_SIMDWIDTH;

    const int32_t i3 = tgpig[2]/args.ne32;
    const int32_t i2 = tgpig[2]%args.ne32;
    const int32_t i1 = tgpig[1];
    const int32_t i0 = tgpig[0];

    char res = i0*C + C > args.ne30 ? 1 : 0;

    // detailed check of the elements of the block
    // note: every full block is scanned. A `(C > NW || Q > 1)` guard here would leave
    //       res == 0 ("fully masked, skip") for C <= NW && Q == 1 without reading the mask.
    if (res == 0) {
        device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;

        // columns visited per lane, rounded up so that C < NW and C % NW != 0 are covered
        const int32_t NC = (C + NW - 1)/NW;

        half mmin =  MAXHALF;
        half mmax = -MAXHALF;

        FOR_UNROLL (short j = 0; j < Q; ++j) {
            // rows at or past ne01 only pad the last query block - they are not read
            // (the mask is not required to have them)
            if (i1*Q + j < args.ne01) {
                FOR_UNROLL (short ii = 0; ii < NC; ++ii) {
                    // the bound check folds away when C is a multiple of the simd width
                    if (C % NW == 0 || ii*NW + tiisg < C) {
                        const half m = mask_src[ii*NW];

                        mmin = min(mmin, m);
                        mmax = max(mmax, m);
                    }
                }
            }

            mask_src += args.nb31/2;
        }

        // reduce across the simdgroup (executed by all lanes)
        mmin = simd_min(mmin);
        mmax = simd_max(mmax);

        // at least one element above -MAXHALF -> the block contributes to the attention
        if (mmax > -MAXHALF) {
            // every element is exactly zero -> adding the mask can be skipped
            res = (mmin == 0.0 && mmax == 0.0) ? 2 : 1;
        }
    }

    const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
    const int32_t nblk0 = ((args.ne30 + C - 1)/C);

    if (tiisg == 0) {
        dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
    }
}
5012
// function constants of the flash-attention kernels (fixed when the pipeline is specialized)
// has_mask  - a mask is provided: it is staged into shared memory and the per-block `blk` flags are consulted
// has_sinks - presumably: a `sinks` input is provided (its use is not visible in this section)
// has_bias  - scale the mask by the ALiBi slope before adding it to the scores
// has_scap  - apply logit soft-capping to the scaled scores: s = logit_softcap*tanh(s)
// has_kvpad - read the last partial chunk of the KV cache from the buffer written by kernel_flash_attn_ext_pad
constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];

// bc_mask - bounds-check mask rows: query rows at or past ne31 use -MAXHALF instead of reading the mask
constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];

// unused; note that index +10 is now taken by bc_mask above
//constant float FC_flash_attn_ext_scale         [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_max_bias      [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];

// ns10 / ns20 - row strides of K / V in elements (leading dimension for simdgroup_load)
// nsg         - number of simdgroups; presumably matches the NSG template parameter of the kernel
constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
constant int32_t FC_flash_attn_ext_nsg  [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
5028
5029// ref: https://arxiv.org/pdf/2307.08691.pdf
5030template<
5031    typename q_t,     // query types in shared memory
5032    typename q4_t,
5033    typename q8x8_t,
5034    typename k_t,     // key types in shared memory
5035    typename k4x4_t,
5036    typename k8x8_t,
5037    typename v_t,     // value types in shared memory
5038    typename v4x4_t,
5039    typename v8x8_t,
5040    typename qk_t,    // Q*K types
5041    typename qk8x8_t,
5042    typename s_t,     // soft-max types
5043    typename s2_t,
5044    typename s8x8_t,
5045    typename o_t,     // attention accumulation types
5046    typename o4_t,
5047    typename o8x8_t,
5048    typename kd4x4_t, // key type in device memory
5049    short nl_k,
5050    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
5051    typename vd4x4_t, // value type in device memory
5052    short nl_v,
5053    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5054    short DK,         // K head size
5055    short DV,         // V head size
5056    short Q,          // queries per threadgroup
5057    short C,          // cache items per threadgroup
5058    short NSG>        // number of simd groups
5059void kernel_flash_attn_ext_impl(
5060        constant ggml_metal_kargs_flash_attn_ext & args,
5061        device const char * q,
5062        device const char * k,
5063        device const char * v,
5064        device const char * mask,
5065        device const char * sinks,
5066        device const char * pad,
5067        device const char * blk,
5068        device       char * dst,
5069        threadgroup  half * shmem_f16,
5070        uint3   tgpig,
5071        ushort  tiisg,
5072        ushort  sgitg) {
5073    const ushort iq3 = tgpig[2];
5074    const ushort iq2 = tgpig[1];
5075    const ushort iq1 = tgpig[0]*Q;
5076
5077#define NS10 (FC_flash_attn_ext_ns10)
5078#define NS20 (FC_flash_attn_ext_ns20)
5079
5080    // note: I had some concerns that using this instead of the ugly macros above was affecting performance
5081    //       need to re-check carefully and if no regressions are observerd - remove the macros
5082    //       the concerns is that maybe using const variables requires extra registers? but not sure if the compiler
5083    //         is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
5084    //const short NS10 = FC_flash_attn_ext_ns10;
5085    //const short NS20 = FC_flash_attn_ext_ns20;
5086
5087    constexpr short KV   = 8;
5088
5089    constexpr short DK4  = DK/4;
5090    constexpr short DK8  = DK/8;
5091    constexpr short DK16 = DK/16;
5092    constexpr short DV4  = DV/4;
5093  //constexpr short DV8  = DV/8;
5094    constexpr short DV16 = DV/16;
5095
5096    constexpr short PV   = PAD2(DV, 64);
5097    constexpr short PV4  = PV/4;
5098    constexpr short PV8  = PV/8;
5099  //constexpr short PV16 = PV/16;
5100
5101    constexpr short NW  = N_SIMDWIDTH;
5102    constexpr short NQ  = Q/NSG;
5103    constexpr short SH  = 2*C; // shared memory per simdgroup (s_t == float)
5104
5105    constexpr short TS = 2*SH;
5106    constexpr short T  = DK + 2*PV; // shared memory size per query in (half)
5107
5108    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*T); // holds the query data
5109    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
5110    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
5111    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
5112    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
5113    threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
5114
5115    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
5116    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
5117
5118    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
5119    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
5120
5121    // mask storage in shared mem
5122    threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
5123
5124    // per-query mask pointers
5125    device const half2 * pm2[NQ];
5126
5127    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5128        const short j = jj*NSG + sgitg;
5129
5130        pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
5131    }
5132
5133    {
5134        const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
5135        const int32_t nblk0 = ((args.ne11 + C - 1)/C);
5136
5137        blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
5138    }
5139
5140    {
5141        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
5142
5143        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5144        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5145
5146        k += ikv2*args.nb12 + ikv3*args.nb13;
5147        v += ikv2*args.nb22 + ikv3*args.nb23;
5148    }
5149
5150    // load heads from Q to shared memory
5151    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5152        const short j = jj*NSG + sgitg;
5153
5154        device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
5155
5156        for (short i = tiisg; i < DK4; i += NW) {
5157            if (iq1 + j < args.ne01) {
5158                sq4[j*DK4 + i] = (q4_t) q4[i];
5159            } else {
5160                sq4[j*DK4 + i] = 0;
5161            }
5162        }
5163    }
5164
5165    // zero out
5166    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5167        const short j = jj*NSG + sgitg;
5168
5169        for (short i = tiisg; i < DV4; i += NW) {
5170            so4[j*PV4 + i] = 0;
5171        }
5172
5173        for (short i = tiisg; i < SH; i += NW) {
5174            ss[j*SH + i] = 0.0f;
5175        }
5176    }
5177
5178    threadgroup_barrier(mem_flags::mem_threadgroup);
5179
5180    float S[NQ] = { [0 ... NQ-1] = 0.0f };
5181
5182    {
5183        float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
5184
5185        float slope = 1.0f;
5186
5187        // ALiBi
5188        if (FC_flash_attn_ext_has_bias) {
5189            const short h = iq2;
5190
5191            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
5192            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
5193
5194            slope = pow(base, exph);
5195        }
5196
5197        // loop over the KV cache
5198        // each simdgroup handles blocks of Q rows and C columns
5199        for (int ic0 = 0; ; ++ic0) {
5200            int ic = ic0*C;
5201            if (ic >= args.ne11) {
5202                break;
5203            }
5204
5205            // the last partial chunk uses the pad buffer as source
5206            if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
5207                k    = pad;
5208                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
5209                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
5210
5211                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5212                const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5213
5214                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
5215                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
5216
5217                if (!FC_flash_attn_ext_has_mask) {
5218                    threadgroup half * sm = (threadgroup half *) (sm2);
5219
5220                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5221                        const short j = jj*NSG + sgitg;
5222
5223                        for (short i = tiisg; i < C; i += NW) {
5224                            if (ic + i >= args.ne11) {
5225                                sm[2*j*SH + i] = -MAXHALF;
5226                            }
5227                        }
5228                    }
5229                } else {
5230                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5231                        const short j = jj*NSG + sgitg;
5232
5233                        pm2[jj] = (device const half2 *) ((device const half *) mask +
5234                                (iq1 + j)*C +
5235                                (iq2%args.ne32)*(C*args.ne31) +
5236                                (iq3%args.ne33)*(C*args.ne31*args.ne32));
5237                    }
5238                }
5239
5240                ic = 0;
5241            }
5242
5243            char blk_cur = 1;
5244
5245            // read the mask into shared mem
5246            if (FC_flash_attn_ext_has_mask) {
5247                blk_cur = blk[ic0];
5248
5249                if (blk_cur == 0) {
5250                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5251                        pm2[jj] += NW;
5252                    }
5253
5254                    continue;
5255                }
5256
5257                if (blk_cur == 1) {
5258                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5259                        const short j = jj*NSG + sgitg;
5260
5261                        if (FC_flash_attn_ext_bc_mask) {
5262                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5263                        } else {
5264                            sm2[j*SH + tiisg] = pm2[jj][tiisg];
5265                        }
5266
5267                        pm2[jj] += NW;
5268                    }
5269                } else if (blk_cur == 2) {
5270                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5271                        pm2[jj] += NW;
5272                    }
5273                }
5274
5275#if 0
5276                // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
5277
5278                threadgroup_barrier(mem_flags::mem_threadgroup);
5279
5280                // used to detect blocks full of -INF
5281                // skip only when the entire threadgroup is masked
5282                half2 smax2(-MAXHALF/2, -MAXHALF/2);
5283
5284                FOR_UNROLL (short j = 0; j < Q; ++j) {
5285                    smax2 = max(smax2, sm2[j*SH + tiisg]);
5286                }
5287
5288                smax2 = simd_max(smax2);
5289
5290                if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
5291                    // this barrier is important
5292                    threadgroup_barrier(mem_flags::mem_threadgroup);
5293
5294                    continue;
5295                }
5296#endif
5297            }
5298
5299            // Q*K^T
5300            // this is compile-time check, so it does not have runtime overhead
5301            if (is_same<kd4x4_t, k4x4_t>::value) {
5302                // we can read directly from global memory
5303                device      const k_t * pk = (device const k_t *) (k + ic*args.nb11);
5304                threadgroup const q_t * pq = sq;
5305                threadgroup       s_t * ps = ss;
5306
5307                pk += sgitg*(8*NS10);
5308                ps += sgitg*(8*1);
5309
5310                static_assert((C/8) % NSG == 0, "");
5311
5312                constexpr short NC = (C/8)/NSG;
5313
5314                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5315                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5316
5317                    if (DK % 16 != 0) {
5318                        k8x8_t mk;
5319                        q8x8_t mq;
5320
5321                        FOR_UNROLL (short i = 0; i < DK8; ++i) {
5322                            simdgroup_barrier(mem_flags::mem_none);
5323
5324                            simdgroup_load(mk, pk + 8*i, NS10, 0, true);
5325                            simdgroup_load(mq, pq + 8*i, DK);
5326
5327                            simdgroup_barrier(mem_flags::mem_none);
5328
5329                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5330                        }
5331                    } else {
5332                        k8x8_t mk[2];
5333                        q8x8_t mq[2];
5334
5335                        // note: too much unroll can tank the performance for large heads
5336                        #pragma unroll (MIN(DK8/2, 4*NSG))
5337                        for (short i = 0; i < DK8/2; ++i) {
5338                            simdgroup_barrier(mem_flags::mem_none);
5339
5340                            simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
5341                            simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
5342
5343                            simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
5344                            simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
5345
5346                            simdgroup_barrier(mem_flags::mem_none);
5347
5348                            simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
5349                            simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
5350                        }
5351                    }
5352
5353                    simdgroup_store(mqk, ps, SH, 0, false);
5354
5355                    pk += 8*(NSG*NS10);
5356                    ps += 8*(NSG);
5357                }
5358            } else {
5359                // TODO: this is the quantized K cache branch - not optimized yet
5360                for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
5361                    const short cc = ccc*NSG + sgitg;
5362
5363                    const short tx = tiisg%4;
5364                    const short ty = tiisg/4;
5365
5366                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5367
5368                    for (short ii = 0; ii < DK16; ii += 4) {
5369                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
5370
5371                        if (DK16%4 == 0) {
5372                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks
5373                            {
5374                                k4x4_t tmp;
5375                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
5376                                sk4x4[4*ty + tx] = tmp;
5377                            }
5378
5379                            simdgroup_barrier(mem_flags::mem_threadgroup);
5380
5381                            FOR_UNROLL (short k = 0; k < 4; ++k) {
5382                                k8x8_t mk;
5383                                q8x8_t mq;
5384
5385                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
5386                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
5387                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5388
5389                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
5390                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
5391                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5392                            }
5393                        } else {
5394                            if (ii + tx < DK16) {
5395                                k4x4_t tmp;
5396                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
5397                                sk4x4[4*ty + tx] = tmp;
5398                            }
5399
5400                            simdgroup_barrier(mem_flags::mem_threadgroup);
5401
5402                            for (short k = 0; k < 4 && ii + k < DK16; ++k) {
5403                                k8x8_t mk;
5404                                q8x8_t mq;
5405
5406                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
5407                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
5408                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5409
5410                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
5411                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
5412                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5413                            }
5414                        }
5415                    }
5416
5417                    simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
5418                }
5419            }
5420
5421            threadgroup_barrier(mem_flags::mem_threadgroup);
5422
5423            // online softmax
5424            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5425                const short j = jj*NSG + sgitg;
5426
5427                const float m = M[jj];
5428
5429                // scale and apply the logitcap / mask
5430                float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
5431
5432                if (FC_flash_attn_ext_has_scap) {
5433                    s2 = args.logit_softcap*precise::tanh(s2);
5434                }
5435
5436                // mqk = mqk + slope*mask
5437                if (blk_cur != 2) {
5438                    if (FC_flash_attn_ext_has_bias) {
5439                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
5440                    } else {
5441                        s2 += s2_t(sm2[j*SH + tiisg]);
5442                    }
5443                }
5444
5445                M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
5446
5447                const float  ms  = exp(m  - M[jj]);
5448                const float2 vs2 = exp(s2 - M[jj]);
5449
5450                S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
5451
5452                // the P matrix from the paper (Q rows, C columns)
5453                ss2[j*SH/2 + tiisg] = vs2;
5454
5455                if (DV4 % NW == 0) {
5456                    FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
5457                        const short i = ii*NW + tiisg;
5458
5459                        so4[j*PV4 + i] *= ms;
5460                    }
5461                } else {
5462                    for (short i = tiisg; i < DV4; i += NW) {
5463                        so4[j*PV4 + i] *= ms;
5464                    }
5465                }
5466            }
5467
5468            threadgroup_barrier(mem_flags::mem_threadgroup);
5469
5470            // O = O + (Q*K^T)*V
5471            {
5472                // we can read directly from global memory
5473                if (is_same<vd4x4_t, v4x4_t>::value) {
5474                    static_assert(PV8 % NSG == 0, "");
5475
5476                    constexpr short NO = PV8/NSG;
5477
5478                    o8x8_t lo[NO];
5479
5480                    {
5481                        auto sot = so + 8*sgitg;
5482
5483                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
5484                            simdgroup_load(lo[ii], sot, PV, 0, false);
5485
5486                            sot += 8*NSG;
5487                        }
5488                    }
5489
5490                    {
5491                        device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
5492
5493                        pv += 8*sgitg;
5494
5495                        if (DV <= 64) {
5496                            FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
5497                                s8x8_t vs;
5498                                simdgroup_load(vs, ss + 8*cc, SH, 0, false);
5499
5500                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5501                                    v8x8_t mv[2];
5502
5503                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
5504                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
5505
5506                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
5507                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
5508                                }
5509
5510                                pv  += 8*NS20;
5511                            }
5512                        } else {
5513                            constexpr short NC = (C/8)/2;
5514
5515                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5516                                s8x8_t vs[2];
5517
5518                                simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
5519                                simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
5520
5521                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5522                                    v8x8_t mv[4];
5523
5524                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5525                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5526                                    simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5527                                    simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5528
5529                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
5530                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
5531                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
5532                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
5533                                }
5534
5535                                pv  += 2*8*NS20;
5536                            }
5537                        }
5538                    }
5539
5540                    {
5541                        auto sot = so + 8*sgitg;
5542
5543                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
5544                            simdgroup_store(lo[ii], sot, PV, 0, false);
5545
5546                            sot += 8*NSG;
5547                        }
5548                    }
5549                } else {
5550                    // TODO: this is the quantized V cache branch - not optimized yet
5551
5552                    const short tx = tiisg%4;
5553                    const short ty = tiisg/4;
5554
5555                    for (short cc = 0; cc < C/8; ++cc) {
5556                        s8x8_t vs;
5557                        simdgroup_load(vs, ss + 8*cc, SH, 0, false);
5558
5559                        for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
5560                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
5561
5562                            if (DV16%4 == 0) {
5563                                // no need for bound checks
5564                                {
5565                                    v4x4_t tmp;
5566                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
5567                                    sv4x4[4*ty + tx] = tmp;
5568                                }
5569
5570                                simdgroup_barrier(mem_flags::mem_threadgroup);
5571
5572                                FOR_UNROLL (short k = 0; k < 4; ++k) {
5573                                    v8x8_t mv[2];
5574                                    o8x8_t lo[2];
5575
5576                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
5577                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
5578                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5579                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5580
5581                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
5582                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
5583
5584                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5585                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5586                                }
5587                            } else {
5588                                if (ii + tx < DV16) {
5589                                    v4x4_t tmp;
5590                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
5591                                    sv4x4[4*ty + tx] = tmp;
5592                                }
5593
5594                                simdgroup_barrier(mem_flags::mem_threadgroup);
5595
5596                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {
5597                                    v8x8_t mv[2];
5598                                    o8x8_t lo[2];
5599
5600                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
5601                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
5602                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5603                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5604
5605                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
5606                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
5607
5608                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5609                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5610                                }
5611                            }
5612                        }
5613                    }
5614                }
5615            }
5616
5617            threadgroup_barrier(mem_flags::mem_threadgroup);
5618        }
5619
5620        if (FC_flash_attn_ext_has_sinks) {
5621            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5622                const short j = jj*NSG + sgitg;
5623
5624                const float m = M[jj];
5625                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
5626
5627                M[jj] = simd_max(max(M[jj], s));
5628
5629                const float ms = exp(m - M[jj]);
5630                const float vs = exp(s - M[jj]);
5631
5632                S[jj] = S[jj]*ms + simd_sum(vs);
5633
5634                for (short i = tiisg; i < DV4; i += NW) {
5635                    so4[j*PV4 + i] *= ms;
5636                }
5637            }
5638        }
5639    }
5640
5641    // store to global memory
5642    for (short jj = 0; jj < NQ; ++jj) {
5643        const short j = jj*NSG + sgitg;
5644        if (iq1 + j >= args.ne01) {
5645            break;
5646        }
5647
5648        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
5649
5650        const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
5651
5652        if (DV4 % NW == 0) {
5653            FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
5654                const short i = ii*NW + tiisg;
5655
5656                dst4[i] = (float4) so4[j*PV4 + i]*scale;
5657            }
5658        } else {
5659            for (short i = tiisg; i < DV4; i += NW) {
5660                dst4[i] = (float4) so4[j*PV4 + i]*scale;
5661            }
5662        }
5663    }
5664
5665#undef NS10
5666#undef NS20
5667}
5668
// Flash-attention kernel entry point.
//
// Every type, dequantization routine and head-size choice arrives as a template parameter. The number
// of simdgroups per threadgroup is read from the FC_flash_attn_ext_nsg function constant and used to
// pick the matching kernel_flash_attn_ext_impl<..., NSG> specialization. Only NSG = 4 and NSG = 8 are
// instantiated; for any other value of the constant this kernel performs no work.
template<
    typename q_t,     // query types in shared memory
    typename q4_t,
    typename q8x8_t,
    typename k_t,     // key types in shared memory
    typename k4x4_t,
    typename k8x8_t,
    typename v_t,     // value types in shared memory
    typename v4x4_t,
    typename v8x8_t,
    typename qk_t,    // Q*K types
    typename qk8x8_t,
    typename s_t,     // soft-max types
    typename s2_t,
    typename s8x8_t,
    typename o_t,     // attention accumulation types
    typename o4_t,
    typename o8x8_t,
    typename kd4x4_t, // key type in device memory
    short nl_k,
    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
    typename vd4x4_t, // value type in device memory
    short nl_v,
    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
    short DK,         // K head size
    short DV,         // V head size
    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
    short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
        constant ggml_metal_kargs_flash_attn_ext & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device const char * blk,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    // Pass-through lists: all template parameters and all kernel arguments are forwarded unchanged;
    // the only thing decided here is the NSG template argument appended at the end.
#define FA_IMPL_TPARAMS q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FA_IMPL_KARGS   args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg

    // note: the NSG = 1 and NSG = 2 variants
    //   kernel_flash_attn_ext_impl<FA_IMPL_TPARAMS, 1>(FA_IMPL_KARGS);
    //   kernel_flash_attn_ext_impl<FA_IMPL_TPARAMS, 2>(FA_IMPL_KARGS);
    // are intentionally not instantiated to reduce library load time.
    if (FC_flash_attn_ext_nsg == 8) {
        kernel_flash_attn_ext_impl<FA_IMPL_TPARAMS, 8>(FA_IMPL_KARGS);
    } else if (FC_flash_attn_ext_nsg == 4) {
        kernel_flash_attn_ext_impl<FA_IMPL_TPARAMS, 4>(FA_IMPL_KARGS);
    }

#undef FA_IMPL_TPARAMS
#undef FA_IMPL_KARGS
}
5723
// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
//       template parameters to be able to explore different combinations
//
// Each FA_TYPES* list supplies the first 17 template arguments of kernel_flash_attn_ext, in this order:
//   q_t,  q4_t,   q8x8_t   - query types in shared memory
//   k_t,  k4x4_t, k8x8_t   - key types in shared memory
//   v_t,  v4x4_t, v8x8_t   - value types in shared memory
//   qk_t,         qk8x8_t  - Q*K types (no 4-wide variant, so this row has only two entries)
//   s_t,  s2_t,   s8x8_t   - soft-max types
//   o_t,  o4_t,   o8x8_t   - attention accumulation types
// The commented-out line after each list is an alternative (currently unused) choice for the o_* row.
// Note: no // comments may appear inside the macro bodies, since line splicing would fold the next line into them.
//
// FA_TYPES: half Q/K/V in shared memory, float soft-max and accumulation.
// Used below for the f16 and all quantized K/V caches.
#define FA_TYPES \
    half,   half4,     simdgroup_half8x8,  \
    half,   half4x4,   simdgroup_half8x8,  \
    half,   half4x4,   simdgroup_half8x8,  \
    float,             simdgroup_float8x8, \
    float,  float2,    simdgroup_float8x8, \
    float,  float4,    simdgroup_float8x8
    //half,   half4,     simdgroup_half8x8

// FA_TYPES_BF: bfloat Q/K/V in shared memory, float Q*K and soft-max, half accumulation.
// Only expanded for the bf16 instantiations, which are compiled only when GGML_METAL_HAS_BF16 is defined.
#define FA_TYPES_BF \
    bfloat, bfloat4,   simdgroup_bfloat8x8, \
    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
    float,             simdgroup_float8x8,  \
    float,  float2,    simdgroup_float8x8,  \
    half,   half4,     simdgroup_half8x8
    //float,  float4,    simdgroup_float8x8

// FA_TYPES_F32: half Q but float K/V in shared memory, float soft-max and accumulation.
// Used below for the f32 K/V cache.
#define FA_TYPES_F32 \
    half,   half4,     simdgroup_half8x8,  \
    float,  float4x4,  simdgroup_float8x8, \
    float,  float4x4,  simdgroup_float8x8, \
    float,             simdgroup_float8x8, \
    float,  float2,    simdgroup_float8x8, \
    float,  float4,    simdgroup_float8x8
    //half,   half4,     simdgroup_half8x8
5753
// Function type shared by every explicit instantiation below. The kernel's parameter list does not depend on
// any template parameter, so the particular template arguments used here only serve to name that signature.
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;

// Explicit instantiations. host_name sets the function name exposed to the host API; the pattern is
//   kernel_flash_attn_ext_<K/V type>_dk<DK>_dv<DV>
// Every K/V type gets the same set of (DK, DV) pairs; 192/128 and 576/512 are the only pairs with DK != DV.
//
// f32 K/V cache: nl = 1, i.e. one float4x4 (16 values) per dequantization call, taken directly via dequantize_f32.
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  32,  32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  40,  40>;
template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  48,  48>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  64,  64>;
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  72,  72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  80,  80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  96,  96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4,   1, dequantize_f32,  float4x4,   1, dequantize_f32,  576, 512>;
5769
// f16 K/V cache: default FA_TYPES, nl = 1 (one half4x4 per dequantization call via dequantize_f16).
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  32,  32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  40,  40>;
template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  48,  48>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64,  64>;
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  72,  72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80,  80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96,  96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112, 112>;
template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 192>;
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  576, 512>;
5783
// bf16 K/V cache: FA_TYPES_BF, nl = 1 (one bfloat4x4 per dequantization call via dequantize_bf16).
// Only compiled when the target supports bfloat (GGML_METAL_HAS_BF16 is undefined for __METAL_VERSION__ < 310).
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 32,  32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 40,  40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 48,  48>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 72,  72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;
#endif
5799
// Quantized K/V caches: default FA_TYPES, nl = 2. The kernel dequantizes element group i as
// deq(block_ptr + i/nl, i%nl, out), so each block_q* is decoded as 2 consecutive 4x4 (16-value) tiles.
// K and V always use the same quantization type in these instantiations.
//
// q4_0
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48,  48>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;

// q4_1
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48,  48>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;

// q5_0
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48,  48>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;

// q5_1
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48,  48>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;

// q8_0
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48,  48>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES,    block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
5869
// The type-list macros exist only to spell out the explicit instantiations of kernel_flash_attn_ext;
// drop them so they cannot leak into the rest of the file.
#undef FA_TYPES
#undef FA_TYPES_BF
#undef FA_TYPES_F32
5873
// Function constants used to specialize kernel_flash_attn_ext_vec when the pipeline is built.
// Indices are offsets from FC_FLASH_ATTN_EXT_VEC, which is defined outside this file (presumably in ggml-metal-impl.h).
//
// Feature toggles (offsets 0-4). The exact meaning is inferred from the names only: mask present,
// attention sinks present, bias (presumably ALiBi max_bias) enabled, logit softcap enabled,
// KV padding enabled. TODO confirm against the host-side pipeline setup.
constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];

// Offsets 10-12 are reserved but disabled: scale, max_bias and logit_softcap are not specialized as
// function constants (presumably they are read from the kernel args instead - confirm).
//constant float FC_flash_attn_ext_vec_scale         [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias      [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];

// Integer specializations (offsets 20-23), read in kernel_flash_attn_ext_vec through the NS10, NS20, NSG and NWG macros.
//   ns10 / ns20 : presumably the K / V row strides in elements (NS20 is the V load stride in the
//                 non-vec kernel) - confirm
//   nsg         : NSG, presumably simdgroups per threadgroup
//   nwg         : NWG; the kernel splits tgpig[2] into iq3 = tgpig[2]/NWG and iwg = tgpig[2]%NWG
constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
constant int32_t FC_flash_attn_ext_vec_nsg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
constant int32_t FC_flash_attn_ext_vec_nwg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];
5888
// Single-query ("vec") flash-attention kernel.
//
// Grid layout, as decoded below:
//   tgpig[0] -> query row iq1
//   tgpig[1] -> head      iq2
//   tgpig[2] -> iq3*NWG + iwg, where iwg selects one of NWG workgroups that split the KV cache
//
// Each threadgroup runs NSG simdgroups of NW lanes. Simdgroup sgitg of workgroup iwg processes the KV
// chunks ic0 = iwg*NSG + sgitg, iwg*NSG + sgitg + NWG*NSG, ..., each chunk holding C cache items. Every
// simdgroup keeps its own online-softmax state (S, M) and an unnormalised O accumulator, and the NSG
// states are merged in threadgroup memory at the end. With NWG == 1 the result is normalised by 1/S.
// Otherwise the unnormalised O plus (S, M) are written out, presumably for
// kernel_flash_attn_ext_vec_reduce to combine (TODO confirm on the host side).
//
// Threadgroup memory layout in half units (from the pointer setup below):
//   [0, PK)                                 : query row (q4_t), zero-padded from DK to PK
//   NSG*PK + sgitg*SH + [0, C*sizeof(s_t)/2) : scores, then P values, of simdgroup sgitg (s_t)
//   NSG*PK + sgitg*SH + [2*C, 3*C)           : mask values of simdgroup sgitg (half)
//   NSG*PK + NSG*SH + 2*sgitg*PV + ...       : O accumulator of simdgroup sgitg (2*PV halves)
template<
    typename q4_t,  // query types in shared memory
    typename k4_t,  // key types in shared memory
    typename v4_t,  // value types in shared memory
    typename qk_t,  // Q*K types
    typename s_t,   // soft-max types
    typename s4_t,
    typename o4_t,  // attention accumulation types
    typename kd4_t, // key type in device memory
    short nl_k,     // number of 4-element groups per kd4_t item (see deq_k_t4 indexing)
    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
    typename vd4_t, // value type in device memory
    short nl_v,     // number of 4-element groups per vd4_t item (see deq_v_t4 indexing)
    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
    short DK,       // K head size
    short DV,       // V head size
    short NE = 4,   // cache items a simdgroup processes concurrently; each uses NL = NW/NE lanes
    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup (not referenced in the body)
    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per simdgroup iteration
kernel void kernel_flash_attn_ext_vec(
        constant ggml_metal_kargs_flash_attn_ext_vec & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    static_assert(DK % 32 == 0, "DK must be divisible by 32");
    static_assert(DV % 32 == 0, "DV must be divisible by 32");

#define NWG  (FC_flash_attn_ext_vec_nwg)
#define NSG  (FC_flash_attn_ext_vec_nsg)

#define NS10 (FC_flash_attn_ext_vec_ns10)
#define NS20 (FC_flash_attn_ext_vec_ns20)

    const short iwg = tgpig[2]%NWG;

    const ushort iq3 = tgpig[2]/NWG;
    const ushort iq2 = tgpig[1];
    const ushort iq1 = tgpig[0];

    constexpr short DK4 = DK/4;
    constexpr short DV4 = DV/4;

    constexpr short PK  = PAD2(DK, 128);
    constexpr short PK4 = PK/4;

    constexpr short PV  = PAD2(DV, 128);
    constexpr short PV4 = PV/4;

    constexpr short NW  = N_SIMDWIDTH;
    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
    constexpr short SH  = 4*C;   // shared memory per simdgroup (in half)

    // the same per-simdgroup scratch size, counted in s_t and s4_t elements
    // (SH is in half, so these depend on the element sizes and must not be hard-coded)
    constexpr short SH_S  = (short) ((SH*sizeof(half))/sizeof(s_t));
    constexpr short SH_S4 = (short) ((SH*sizeof(half))/sizeof(s4_t));

    static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
    static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");

    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // holds the query data
    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention
    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t
    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results

    // store the result for all queries in shared memory (the O matrix from the paper)
    // each lane addresses the accumulator relative to its own lane index
    so4 += tiisg;

    {
        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;

        // map the query head onto its (possibly broadcast) K/V head
        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
        const short ikv3 = iq3/(args.ne03/args.ne_12_3);

        k += ikv2*args.nb12 + ikv3*args.nb13;
        v += ikv2*args.nb22 + ikv3*args.nb23;
    }

    // load heads from Q to shared memory, zero-padding from DK4 up to PK4
    device const float4 * q4 = (device const float4 *) ((device const char *) q);

    if (iq1 < args.ne01) {
        for (short i = tiisg; i < PK4; i += NW) {
            if (i < DK4) {
                sq4[i] = (q4_t) q4[i];
            } else {
                sq4[i] = (q4_t) 0.0f;
            }
        }
    }

    // zero out so
    // only the first NL lanes (ty == 0) own the accumulator: lane tx covers tx + ii*NL for
    // ii < DV4/NL, i.e. exactly [0, DV4). If the other lanes also wrote, they would go past DV4 and
    // can overrun this simdgroup's PV4-sized slot (e.g. DV = 96 with NE = 4 reaches index 47 > PV4 = 32)
    if (tiisg < NL) {
        for (short i = 0; i < DV4/NL; ++i) {
            so4[i*NL] = (o4_t) 0.0f;
        }
    }

    // zero out this simdgroup's SH scratch (scores + mask); without a mask the zeros act as a 0 bias
    for (short i = tiisg; i < SH_S4; i += NW) {
        ss4[i] = (s4_t) 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    {
        float S = 0.0f;
        float M = -FLT_MAX/2;

        // thread indices inside the simdgroup: tx walks the head dimension, ty selects one of NE cache items
        const short tx = tiisg%NL;
        const short ty = tiisg/NL;

        // pointer to the mask
        device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

        float slope = 1.0f;

        // ALiBi
        if (FC_flash_attn_ext_vec_has_bias) {
            const short h = iq2;

            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

            slope = pow(base, exph);
        }

        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
            int ic = ic0*C;
            if (ic >= args.ne11) {
                break;
            }

            // the last partial chunk uses the pad buffer as source
            // pad layout (from the arithmetic below): C K rows per KV head, then C V rows per KV head, then the mask
            if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
                k    = pad;
                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;

                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
                const short ikv3 = iq3/(args.ne03/args.ne_12_3);

                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;

                if (!FC_flash_attn_ext_vec_has_mask) {
                    // no user mask: mask out the cache items beyond ne11
                    if (ic + tiisg >= args.ne11) {
                        sm[tiisg] = -MAXHALF;
                    }
                } else {
                    pm = (device const half *) (mask) +
                        iq1*C +
                        (iq2%args.ne32)*(C*args.ne31) +
                        (iq3%args.ne33)*(C*args.ne31*args.ne32);
                }

                ic = 0;
            }

            if (FC_flash_attn_ext_vec_has_mask) {
                sm[tiisg] = pm[ic + tiisg];
            }

            // skip -INF blocks
            if (simd_max(sm[tiisg]) <= -MAXHALF) {
                continue;
            }

            // Q*K^T
            {
                device      const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
                threadgroup const q4_t * pq4 = sq4;

                pk4 += ty*NS10/4 + tx;
                pq4 += tx;

                qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f };

                // each simdgroup processes 1 query and NE (NW/NL) cache elements
                FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                    if (is_same<kd4_t, k4_t>::value) {
                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
                            mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 +  ii*NL], (float4) pq4[ii*NL]);
                        }
                    } else {
                        device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));

                        k4_t mk;

                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
                            const short i = ii*NL + tx;

                            deq_k_t4(pk + i/nl_k, i%nl_k, mk);

                            mqk[cc] += dot((float4) mk, (float4) sq4[i]);
                        }
                    }

                    if (NE == 1) {
                        mqk[cc] = simd_sum(mqk[cc]);
                    } else {
                        // reduce over the NL lanes of each cache item (NE = 4 shown)
                        // [ 0 ..  7] -> [ 0]
                        // [ 8 .. 15] -> [ 8]
                        // [16 .. 23] -> [16]
                        // [24 .. 31] -> [24]
                        if (NE <= 1) {
                            mqk[cc] += simd_shuffle_down(mqk[cc], 16);
                        }
                        if (NE <= 2) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  8);
                        }
                        if (NE <= 4) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  4);
                        }
                        if (NE <= 8) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  2);
                        }
                        if (NE <= 16) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  1);
                        }

                        // broadcast the sum of cache item NE*cc + ty to all lanes of that item
                        mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
                    }
                }

                // lane (tx, ty) finalises cache item NE*tx + ty (mqk has C/NE == NL entries)
                if (FC_flash_attn_ext_vec_has_mask &&
                   !FC_flash_attn_ext_vec_has_scap &&
                   !FC_flash_attn_ext_vec_has_bias) {
                    ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
                } else {
                    mqk[tx] *= args.scale;

                    if (FC_flash_attn_ext_vec_has_scap) {
                        mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
                    }

                    if (FC_flash_attn_ext_vec_has_bias) {
                        mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
                    } else {
                        mqk[tx] += (qk_t) sm[NE*tx + ty];
                    }

                    ss[NE*tx + ty] = mqk[tx];
                }
            }

            simdgroup_barrier(mem_flags::mem_threadgroup);

            // online softmax
            {
                const float m = M;
                const float s = ss[tiisg];

                M = simd_max(max(M, s));

                const float ms = exp(m - M);
                const float vs = exp(s - M);

                S = S*ms + simd_sum(vs);

                // the P matrix from the paper (Q rows, C columns)
                ss[tiisg] = vs;

                // O = diag(ms)*O
                // only the ty == 0 lanes own the accumulator (for NE == 1 that is every lane)
                if (NE == 1 || ty == 0) {
                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                        so4[ii*NL] *= ms;
                    }
                }
            }

            simdgroup_barrier(mem_flags::mem_threadgroup);

            // O = O + (Q*K^T)*V
            {
                o4_t lo[DV4/NL];
                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    lo[ii] = 0.0f;
                }

                if (is_same<vd4_t, v4_t>::value) {
                    device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);

                    pv4 += ty*NS20/4 + tx;

                    const auto sst = ss + ty;

                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                            lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
                        }
                    }
                } else {
                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                        device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));

                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                            const short i = ii*NL + tx;

                            v4_t mv;
                            deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);

                            lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
                        }
                    }
                }

                // sum the partial outputs of the NE lanes that share tx (lanes tx + NL*ty) into lane tx
                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    if (NE > 1) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
                    }

                    if (NE > 2) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  8);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  8);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  8);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  8);
                    }

                    if (NE > 4) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  4);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  4);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  4);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  4);
                    }

                    if (NE > 8) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  2);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  2);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  2);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  2);
                    }

                    if (NE > 16) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0],  1);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1],  1);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2],  1);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3],  1);
                    }
                }

                // only the ty == 0 lanes hold the complete sums
                if (NE == 1 || ty == 0) {
                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                        so4[ii*NL] += lo[ii];
                    }
                }
            }
        }

        // attention sink: one extra logit per head, folded in exactly once (first simdgroup of the first workgroup)
        if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
            const float m = M;
            const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;

            M = simd_max(max(M, s));

            const float ms = exp(m - M);
            const float vs = exp(s - M);

            S = S*ms + simd_sum(vs);

            if (NE == 1 || ty == 0) {
                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    so4[ii*NL] *= ms;
                }
            }
        }

        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        if (tiisg == 0) {
            ss[0] = (s_t) S;
            ss[1] = (s_t) M;
        }
    }

    so4 -= tiisg;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel reduce: fold simdgroup sgitg + r into sgitg, halving r each step
    for (short r = NSG/2; r > 0; r >>= 1) {
        if (sgitg < r) {
            // the scratch of simdgroup sgitg + r starts r*SH halves (= r*SH_S s_t items) further
            const float S0 = ss[          0];
            const float S1 = ss[r*SH_S + 0];

            const float M0 = ss[          1];
            const float M1 = ss[r*SH_S + 1];

            const float M = max(M0, M1);

            const float ms0 = exp(M0 - M);
            const float ms1 = exp(M1 - M);

            const float S = S0*ms0 + S1*ms1;

            if (tiisg == 0) {
                ss[0] = S;
                ss[1] = M;
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            // NOTE(review): the r*PV4 stride assumes o4_t spans 8 halves (float4)
            for (short i = tiisg; i < DV4; i += NW) {
                so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // final rescale with 1/S and store to global memory
    if (sgitg == 0) {
        const int64_t nrows = args.ne3*args.ne2*args.ne1;
        // row index: heads (iq2) innermost, then query rows (iq1), then iq3
        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;

        device float4 * dst4 = (device float4 *) dst;
        device float  * dst1 = (device float  *) dst + nrows*DV*NWG; // the S and M are stored after the results

        // with several workgroups the normalisation is deferred, so keep O unscaled
        const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;

        // interleave the workgroup data
        for (short i = tiisg; i < DV4; i += NW) {
            dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
        }

        // store S and M
        if (NWG > 1) {
            if (tiisg == 0) {
                dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
                dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
            }
        }
    }

#undef NWG
#undef NSG
#undef NS10
#undef NS20
}
6339
// note: s_t could probably be half instead of float here, because the Q*K scaling is applied before the
//       values are stored to shared memory; in the other (non-vec) kernel s_t must stay float because the
//       scaling happens during the soft_max
//
// type bundle for the first seven template parameters of kernel_flash_attn_ext_vec, in order:
//   q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t
// FA_TYPES stages K/V as half4; FA_TYPES_F32 (used with the f32 K/V instantiations) stages them as float4
#define FA_TYPES \
           half4,  \
           half4,  \
           half4,  \
    float,         \
    float, float4, \
           float4

#define FA_TYPES_F32 \
           half4,  \
           float4, \
           float4, \
    float,         \
    float, float4, \
           float4

// common kernel function type for the instantiations below; the kernel's parameter list does not
// depend on the template arguments, so any valid instantiation yields the same type
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
6360
// explicit instantiations of the vec kernel; trailing arguments are <DK, DV, NE>
// NE must make DK/4 and DV/4 divisible by NL = 32/NE (checked by static_assert in the kernel)
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 32, 32, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 64, 64, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 64, 64, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]]    kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 96, 96, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]]   kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 96, 96, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 128, 128, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 128, 128, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 192, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 192, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 192, 128, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 192, 128, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 256, 256, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 256, 256, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4,     1, dequantize_f32_t4,  float4,      1, dequantize_f32_t4,  576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     half4,      1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     bfloat4,    1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 576, 512, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_0, 8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q4_1, 8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_0, 8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q5_1, 8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 576, 512, 2>;

#undef FA_TYPES
#undef FA_TYPES_F32
6451
constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];

// Combines the NWG partial results computed for one output row into the final row of dst.
//
// htmp layout, in floats, as read below:
//   [0, nrows*DV*NWG)                    : partial outputs; for row r, float4 column c and
//                                          partial p the value is at float4 index (r*DV/4 + c)*NWG + p
//   [nrows*DV*NWG, nrows*DV*NWG + 2*NWG*nrows) : one {S, M} float pair per (row, partial)
//
// Lane tiisg handles partial tiisg. The partials are rescaled by exp(M - max(M)); their S values
// are summed with the same factors, and the combined output is divided by that sum
// (0 is written when the sum is exactly zero).
// NOTE(review): the simd_max/simd_sum reductions span a whole SIMD group and the column loop
// steps by NWG across SIMD groups, so this presumably assumes NWG == 32 and NWG SIMD groups per
// threadgroup — confirm against the host-side dispatch.
kernel void kernel_flash_attn_ext_vec_reduce(
        constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
        device  const char * htmp,
        device        char * dst,
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const int32_t nwg = FC_flash_attn_ext_vec_reduce_NWG;
    const int32_t dv  = FC_flash_attn_ext_vec_reduce_DV;
    const short   dv4 = dv/4;

    const uint64_t row  = tgpig; // one threadgroup per output row
    const short    part = tiisg; // partial result handled by this lane

    // {S, M} pairs of this row, stored after all partial output vectors
    device const float * sm_row = (device const float *) htmp + (uint64_t)args.nrows*dv*nwg + row*(2*nwg);

    const float m_part = sm_row[2*part + 1];
    const float m_max  = simd_max(m_part);
    const float scale  = exp(m_part - m_max); // brings this partial onto the common maximum

    const float s_sum = simd_sum(sm_row[2*part + 0]*scale);
    const float s_inv = (s_sum == 0.0f) ? 0.0f : 1.0f/s_sum;

    device const float4 * o_part = (device const float4 *) htmp + row*dv4*nwg;
    device       float4 * o_out  = (device       float4 *) dst  + row*dv4;

    // SIMD group sgitg handles float4 columns sgitg, sgitg + nwg, sgitg + 2*nwg, ...
    for (short c = sgitg; c < dv4; c += nwg) {
        const float4 acc = simd_sum(o_part[c*nwg + part]*scale);

        if (part == 0) {
            o_out[c] = acc*s_inv;
        }
    }
}
6496
// Element-wise copy with type conversion: each destination element gets (T1) of a src0 element.
//
// Grid / threadgroup layout, as decoded below:
//   tgpig.z = i03, tgpig.y = i02
//   ntg[1] == 1 : tgpig.x = iw0*ne01 + i01. Each threadgroup handles one source row, or the
//                 column chunk iw0 of ntg[0] elements of that row.
//   ntg[1] >  1 : tgpig.x selects ntg[1] consecutive source rows, each handled by ntg[0] threads.
// Each thread copies at most one element. The destination coordinates come from the flat
// element index of the source row, so src0 and dst may differ in shape as long as their element
// order matches.
// NOTE(review): dst_data[i00] writes T1 values contiguously from the row's first destination
// element — presumably the host only selects this kernel when that layout holds.
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
        constant ggml_metal_kargs_cpy & args,
        device  const char * src0,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t i00 = iw0*ntg[0] + tiitg%ntg[0];

    // With several rows per threadgroup, the last group can extend past ne01; the last column
    // chunk can extend past ne00. Such threads have no element to copy. An early return is safe
    // here because the kernel uses no barriers or SIMD-group operations.
    if (i01 >= args.ne01 || i00 >= args.ne00) {
        return;
    }

    // flat index of the first element of source row (i01, i02, i03) ...
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // ... decomposed into destination coordinates
    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    device const T0 * src = (device const T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

    dst_data[i00] = (T1) src[0];
}
6525
// Shared function type for all kernel_cpy_t_t instantiations.
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;

// kernel_cpy_<src>_<dst>: element-wise conversions between f32, f16, i32 and (when
// GGML_METAL_HAS_BF16 is defined) bf16.
template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   float>;
template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   half>;
template [[host_name("kernel_cpy_f32_i32")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   int32_t>;
template [[host_name("kernel_cpy_i32_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]]   kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,   bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    float>;
template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  bfloat>;
#endif
6542
// Copies f32 data into a block-quantized destination. Each thread quantizes one block of QK
// source floats with quantize_func into destination block i00 of the row.
//
// Template parameters:
//   QK            - number of source values per quantized block
//   block_q       - destination block type
//   quantize_func - quantizes QK floats (read from the given pointer) into one block
//
// The grid layout matches kernel_cpy_t_t, except that i00 counts blocks (bounded by args.nk0)
// instead of elements. NOTE(review): nk0 is presumably ne00/QK and source rows are presumably
// contiguous over a block — confirm on the host side.
template<short QK,
         typename block_q,
         void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t i00 = iw0*ntg[0] + tiitg%ntg[0];

    // Threads past the last source row (possible when ntg[1] > 1) or past the last block of
    // the row have nothing to do. There are no barriers, so an early return is safe.
    if (i01 >= args.ne01 || i00 >= args.nk0) {
        return;
    }

    // flat index of the first element of source row (i01, i02, i03) ...
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // ... decomposed into destination coordinates (i0 in units of QK-element blocks)
    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;

    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);

    quantize_func(src, dst_data[i00]);
}
6575
// Shared function type for the f32 -> quantized copy kernels.
typedef decltype(kernel_cpy_f32_q<QK8_0,  block_q8_0,  quantize_q8_0>)  cpy_f_q_t;

// kernel_cpy_f32_<quant>: one instantiation per supported destination block type, each with its
// block size and quantization routine.
template [[host_name("kernel_cpy_f32_q8_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0,  block_q8_0,   quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q4_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0,  block_q4_0,   quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1,  block_q4_1,   quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0,  block_q5_0,   quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1,  block_q5_1,   quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
6584
// Copies block-quantized data into a float destination (T4x4 selects float4x4 or half4x4
// output). Each quantized block is split into nl chunks of 16 values (one T4x4 each). Thread
// i00 dequantizes chunk i00%nl of block i00/nl of the source row and stores it as the i00-th
// T4x4 of the destination row.
//
// The grid layout matches kernel_cpy_t_t, except that i00 counts 16-value chunks (bounded by
// args.nk0). NOTE(review): source blocks of a row and destination T4x4 values are addressed
// contiguously — presumably guaranteed by the host.
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
        constant ggml_metal_kargs_cpy & args,
        device  const char * src0,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t i00 = iw0*ntg[0] + tiitg%ntg[0];

    // Threads past the last source row (possible when ntg[1] > 1) or past the last chunk of
    // the row have nothing to do. There are no barriers, so an early return is safe.
    if (i01 >= args.ne01 || i00 >= args.nk0) {
        return;
    }

    // flat index of the first element of source row (i01, i02, i03) ...
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    // ... decomposed into destination coordinates
    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);

    T4x4 temp;
    dequantize_func(src_data + i00/nl, i00%nl, temp);
    dst_data[i00] = temp;
}
6616
// Shared function type for the quantized -> float copy kernels.
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;

// kernel_cpy_<quant>_f32: dequantize to float4x4 (f32 output). nl = 2 splits each block into
// two 16-value chunks.
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;

// kernel_cpy_<quant>_f16: the same template (despite its _f32 name) with half4x4 output.
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
6630
// Concatenates src0 and src1 along axis args.dim into dst, copying one 4-byte float per element.
//
// One threadgroup per destination row (tgpig = {i1, i2, i3}); the threads of the group stride
// over the row's ne0 elements. An element whose coordinates all lie inside src0's extents is
// read from src0. Otherwise it is read from src1, with its coordinate along args.dim shifted
// back by src0's extent on that axis.
// NOTE(review): args.dim indexes a 4-element array, so it is assumed to be in [0, 3] —
// presumably validated on the host.
kernel void kernel_concat(
    constant ggml_metal_kargs_concat & args,
    device  const char * src0,
    device  const char * src1,
    device        char * dst,
    uint3   tgpig[[threadgroup_position_in_grid]],
    ushort3 tpitg[[thread_position_in_threadgroup]],
    ushort3   ntg[[threads_per_threadgroup]]) {

    const int i1 = tgpig.x;
    const int i2 = tgpig.y;
    const int i3 = tgpig.z;

    // position of src1 inside dst: src0's size along the concat axis, zero on the other axes
    int shift[4] = {0, 0, 0, 0};
    if (args.dim == 0) {
        shift[args.dim] = args.ne00;
    } else if (args.dim == 1) {
        shift[args.dim] = args.ne01;
    } else if (args.dim == 2) {
        shift[args.dim] = args.ne02;
    } else {
        shift[args.dim] = args.ne03;
    }

    // the row coordinates are the same for every element handled by this threadgroup
    const bool row_in_src0 = i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03;

    device char * dst_row = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        device const float * src;

        if (row_in_src0 && i0 < args.ne00) {
            src = (device const float *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
        } else {
            src = (device const float *)(src1 + (i3 - shift[3])*args.nb13 + (i2 - shift[2])*args.nb12 + (i1 - shift[1])*args.nb11 + (i0 - shift[0])*args.nb10);
        }

        *((device float *)(dst_row + i0*args.nb0)) = *src;
    }
}
6661
// Matrix-vector product of a Q2_K-quantized matrix (src0) with an f32 vector (src1).
//
// Each SIMD group produces nr0 consecutive outputs starting at first_row. The src1 row comes
// from tgpig.y, and tgpig.z is the batch index, split into i12/i13 via ne12. src0 is shared
// across dims 2/3 through the divisions by args.r2 / args.r3.
//
// Work split inside a SIMD group (32 lanes):
//   ix = tiisg/8 : this lane visits super-blocks ix, ix+4, ix+8, ...
//   iq           : which 128-value half of the super-block
//   ir           : which 8-value slice inside each 32-value run
// NOTE(review): nb = ne00/QK_K truncates, so ne00 is presumably a multiple of QK_K — confirm.
//
// args_t is the kernel-argument type (the wrapper passes a constant reference). shmem is not
// used by this kernel.
template<int nr0, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
    device const float      * y = (device const float      *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const short ix = tiisg/8;  // 0...3
    const short it = tiisg%8;  // 0...7
    const short iq = it/4;     // 0 or 1
    const short ir = it%4;     // 0...3
    const short is = (8*ir)/16;// 0 or 1

    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

    for (int ib = ix; ib < nb; ib += 4) {
        // cache 4 runs of 8 activations (at offsets 0, 32, 64, 96 from y4) and each run's sum
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
        }

        // scale bytes, 2-bit quants (read two bytes at a time) and the halves starting at d
        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (short row = 0; row < nr0; row++) {
            // Quants are masked in place rather than shifted down:
            //  - the high byte of each uint16 carries an extra factor 256, undone by 1.f/256.f;
            //  - the four 2-bit positions (0x03, 0x0c, 0x30, 0xc0) carry factors 1, 4, 16, 64,
            //    undone by the 1.f/1, 1/4, 1/16, 1/64 multipliers below.
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }
            float dall = dh[0];
            // the high scale nibble (& 0xF0) is used unshifted, hence the extra 1/16
            float dmin = dh[1] * 1.f/16.f;
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));

            // advance to the same super-block of the next src0 row (nb01 bytes; /2 for 2-byte types)
            qs += args.nb01/2;
            sc += args.nb01;
            dh += args.nb01/2;
        }

        y4 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce each row across the SIMD group; lane 0 stores it, rows at or past ne0 are skipped
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
6753
// Kernel entry point for Q2_K x f32 matrix-vector multiplication. It forwards to
// kernel_mul_mv_q2_K_f32_impl with N_R0_Q2_K rows per SIMD group and no threadgroup memory.
[[host_name("kernel_mul_mv_q2_K_f32")]]
kernel void kernel_mul_mv_q2_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
6766
// Matrix-vector product of a Q3_K-quantized matrix (src0) with an f32 vector (src1).
//
// Row/batch addressing is the same as in the other *_K mul_mv kernels: each SIMD group produces
// nr0 outputs starting at first_row, for src1 row tgpig.y and batch index tgpig.z.
//
// Work split inside a SIMD group:
//   ix = tiisg%4 : this lane visits super-blocks ix, ix+4, ...
//   tid = tiisg/4 selects the position inside the super-block through ip (0/1), il (0/2) and
//   ir (0/1, giving the 8-value offset l0).
// shmem is not used by this kernel.
template<int nr0, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
    device const float     * yy = (device const float      *) (src1 + offset1);

    float yl[32];

    //const uint16_t kmask1 = 0x3030;
    //const uint16_t kmask2 = 0x0f0f;

    const short tid = tiisg/4;
    const short ix  = tiisg%4;
    const short ip  = tid/4;          // 0 or 1
    const short il  = 2*((tid%4)/2);  // 0 or 2
    const short ir  = tid%2;
    const short l0  = 8*ir;

    // One would think that the Metal compiler would figure out that ip and il can only have
    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
    // with these two tables.
    //
    // Possible masks for the high bit
    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2

    // Possible masks for the low 2 bits
    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};

    const ushort4 hm = mm[2*ip + il/2];

    // The low bits are masked in place, not shifted down. shift, v1/v2 and the final
    // 0.25f / (1 << shift) factors account for the resulting bit positions.
    const short shift = 2*il;

    const float v1 = il == 0 ? 4.f : 64.f;
    const float v2 = 4.f * v1;

    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + il;

    const short q_offset = 32*ip + l0;
    const short y_offset = 128*ip + 32*il + l0;

    device const float * y1 = yy + ix*QK_K + y_offset;

    // scales32 is viewed both as two uint16 (for loading) and as four int8 (for reading)
    uint32_t scales32, aux32;
    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
    thread const int8_t * scales = (thread const int8_t *)&scales32;

    float sumf1[nr0] = {0.f};
    float sumf2[nr0] = {0.f};

    for (int i = ix; i < nb; i += 4) {
        // cache 4 runs of 8 activations at offsets 0, 16, 32, 48 from y1
        for (short l = 0; l < 8; ++l) {
            yl[l+ 0] = y1[l+ 0];
            yl[l+ 8] = y1[l+16];
            yl[l+16] = y1[l+32];
            yl[l+24] = y1[l+48];
        }

        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
        device const half * dh = &x[i].d;

        for (short row = 0; row < nr0; ++row) {
            const float d_all = (float)dh[0];

            // Assemble four 6-bit scales: bits 4-5 come from a[4..5], bits 0-3 come from
            // a[il..il+1]. scales[k] - 32 below gives the signed scale.
            scales16[0] = a[4];
            scales16[1] = a[5];
            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
            scales16[0] = a[il+0];
            scales16[1] = a[il+1];
            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

            // s1/s2 and s4/s5: low-bit dot products (low/high byte of each uint16).
            // s3/s6: sums of activations whose hmask bit is clear; they are subtracted with
            // weight v1/v2.
            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
            for (short l = 0; l < 8; l += 2) {
                const int32_t qs = q[l/2];
                s1 += yl[l+0] * (qs & qm[il/2][0]);
                s2 += yl[l+1] * (qs & qm[il/2][1]);
                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
                s4 += yl[l+16] * (qs & qm[il/2][2]);
                s5 += yl[l+17] * (qs & qm[il/2][3]);
                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
            }
            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
            sumf1[row] += d1 * (scales[0] - 32);
            sumf2[row] += d2 * (scales[2] - 32);

            // same for the second 16 quant bytes (q/h offset by 8 uint16), with scales 1 and 3
            s1 = s2 = s3 = s4 = s5 = s6 = 0;
            for (short l = 0; l < 8; l += 2) {
                const int32_t qs = q[l/2+8];
                s1 += yl[l+8] * (qs & qm[il/2][0]);
                s2 += yl[l+9] * (qs & qm[il/2][1]);
                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
                s4 += yl[l+24] * (qs & qm[il/2][2]);
                s5 += yl[l+25] * (qs & qm[il/2][3]);
                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
            }
            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
            sumf1[row] += d1 * (scales[1] - 32);
            sumf2[row] += d2 * (scales[3] - 32);

            // advance to the same super-block of the next src0 row (nb01 bytes, 2-byte pointers)
            q  += args.nb01/2;
            h  += args.nb01/2;
            a  += args.nb01/2;
            dh += args.nb01/2;
        }

        y1 += 4 * QK_K;
    }

    // Undo the in-place bit positions: sumf2 carries an extra factor 4 (0.25f), and il == 2
    // adds a factor 16 (1 << shift). sumf1 is then reused to hold the reduced row totals.
    for (int row = 0; row < nr0; ++row) {
        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
        sumf1[row] = simd_sum(sumf);
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // lane 0 stores the rows of this SIMD group that lie below ne0
    if (tiisg == 0) {
        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
            dst_f32[first_row + row] = sumf1[row];
        }
    }
}
6918
// Kernel entry point for Q3_K x f32 matrix-vector multiplication. It forwards to
// kernel_mul_mv_q3_K_f32_impl with N_R0_Q3_K rows per SIMD group and no threadgroup memory.
[[host_name("kernel_mul_mv_q3_K_f32")]]
kernel void kernel_mul_mv_q3_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
6931
// Matrix-vector product of a Q4_K-quantized matrix (src0) with an f32 vector (src1).
//
// Each SIMD group produces nr0 outputs starting at first_row, for src1 row tgpig.y and batch
// index tgpig.z. The lane split matches the Q2_K kernel:
//   ix = tiisg/8 (super-blocks ix, ix+4, ...), iq = 64-value half, ir = 8-value slice.
// shmem is not used by this kernel.
template<int nr0, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;

    const short ix = tiisg/8;  // 0...3
    const short it = tiisg%8;  // 0...7
    const short iq = it/4;     // 0 or 1
    const short ir = it%4;     // 0...3

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
    device const float      * y = (device const float      *) (src1 + offset1);

    float yl[16];
    float yh[16];

    float sumf[nr0]={0.f};

    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    // eight unpacked 6-bit values, filled as four uint16 and read as bytes
    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    for (int ib = ix; ib < nb; ib += 4) {
        // cache runs of 8 activations at offsets 0 and 32 (yl) and 128 and 160 (yh), with run sums
        float4 sumy = {0.f, 0.f, 0.f, 0.f};

        for (short i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (short row = 0; row < nr0; row++) {
            // Unpack the packed 6-bit values. sc8[0,1,4,5] weight the quant dot products (with
            // dh[0]); sc8[2,3,6,7] weight the activation sums that are subtracted (with dh[1]).
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            // quants 64 bytes further on, paired with the activations 128 values further on (yh)
            device const uint16_t * q2 = q1 + 32;

            // Nibbles are masked in place: the high byte of each uint16 carries a factor 256
            // (undone by 1.f/256.f), and the high nibble of a byte carries a factor 16 (undone
            // by 1.f/16.f).
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};

            FOR_UNROLL (short i = 0; i < 4; ++i) {
                acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
            }

            sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            // advance to the same super-block of the next src0 row (nb01 bytes, 2-byte pointers)
            q1 += args.nb01/2;
            sc += args.nb01/2;
            dh += args.nb01/2;
        }

        y4 += 4 * QK_K;
    }

    // NOTE(review): int64_t offsets here vs uint64_t in the sibling kernels; equivalent for
    // non-negative im/r1/ne0/ne1.
    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;

    // reduce each row across the SIMD group; lane 0 stores it, rows at or past ne0 are skipped
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
7039
// Kernel entry point for Q4_K x f32 matrix-vector multiplication. It forwards to
// kernel_mul_mv_q4_K_f32_impl with N_R0_Q4_K rows per SIMD group and no threadgroup memory.
[[host_name("kernel_mul_mv_q4_K_f32")]]
kernel void kernel_mul_mv_q4_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
7052
// Matrix-vector product of a Q5_K-quantized matrix (src0) with an f32 vector (src1).
//
// Each SIMD group produces nr0 outputs starting at first_row, for src1 row tgpig.y and batch
// index tgpig.z.
//
// Work split inside a SIMD group:
//   ix = tiisg%4 : this lane visits super-blocks ix, ix+4, ...
//   iq = tid/4 (0/1) and ir = tid%4 (8-value offset l0), where tid = tiisg/4.
// shmem is not used by this kernel.
template<int nr0, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
    device const float     * yy = (device const float      *) (src1 + offset1);

    float sumf[nr0]={0.f};

    float yl[16], yh[16];

    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;

    const short tid = tiisg/4;
    const short ix  = tiisg%4;
    const short iq  = tid/4;
    const short ir  = tid%4;

    const short l0 = 8*ir;
    const short q_offset = 32*iq + l0;
    const short y_offset = 64*iq + l0;

    // masks selecting, in each qh byte, the high (5th) bit for the four activation runs
    // handled by this lane
    const uint8_t hm1 = 1u << (2*iq);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

    // eight unpacked 6-bit values, filled as four uint16 and read as bytes
    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    device const float * y1 = yy + ix*QK_K + y_offset;

    for (int i = ix; i < nb; i += 4) {
        device const uint8_t * q1 = x[i].qs + q_offset;
        device const uint8_t * qh = x[i].qh + l0;
        device const half * dh = &x[i].d;
        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;

        // cache runs of 8 activations at offsets 0 and 32 (yl) and 128 and 160 (yh), with run sums
        device const float * y2 = y1 + 128;
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (short l = 0; l < 8; ++l) {
            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
        }

        for (short row = 0; row < nr0; ++row) {
            // quants 64 bytes further on, paired with the activations 128 values further on (yh)
            device const uint8_t * q2 = q1 + 64;

            // Unpack the packed 6-bit values. sc8[0,1,4,5] weight the quant dot products (with
            // dh[0]); sc8[2,3,6,7] weight the activation sums that are subtracted (with dh[1]).
            sc16[0] = a[0] & kmask1;
            sc16[1] = a[2] & kmask1;
            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

            // acc1: low-4-bit dot products (high nibbles masked in place, hence /16 below)
            // acc2: sums of activations whose high bit is set (weighted by 16 below)
            float4 acc1 = {0.f};
            float4 acc2 = {0.f};
            FOR_UNROLL (short l = 0; l < 8; ++l) {
                uint8_t h = qh[l];
                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
            }

            sumf[row] += dh[0] * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +
                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
                                  sc8[4] * (acc1[2]      + 16.f*acc2[2]) +
                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            // advance to the same super-block of the next src0 row (byte pointers move by nb01,
            // 2-byte pointers by nb01/2)
            q1 += args.nb01;
            qh += args.nb01;
            dh += args.nb01/2;
            a  += args.nb01/2;
        }

        y1 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce each row across the SIMD group; lane 0 stores it, rows at or past ne0 are skipped
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = tot;
        }
    }
}
7170
[[host_name("kernel_mul_mv_q5_K_f32")]]
kernel void kernel_mul_mv_q5_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the q5_K x f32 mat-vec: forwards every argument to the
    // templated impl. No threadgroup buffer is bound for this variant, so the
    // impl receives a null shmem pointer.
    constexpr int nr0 = N_R0_Q5_K;
    threadgroup char * no_shmem = nullptr;

    kernel_mul_mv_q5_K_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, no_shmem, tgpig, tiisg, sgitg);
}
7183
// Matrix-vector product of a q6_K-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - lanes are paired: ix = tiisg%2 selects every other super-block
//     (i = ix, ix+2, ...), while tid = tiisg/2 selects which 4 consecutive quants
//     (l0..l0+3) and which 128-value half (ip) of the super-block the lane handles.
// Each lane's per-row partial sums are reduced with simd_sum and lane 0 writes
// them to dst. shmem is not used by this impl.
template<int nr0, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    // Masks that pick the 2 high bits of a quant out of a qh byte, one per
    // 32-element group handled by the lane (see the unpacking loop below).
    constexpr uint8_t kmask1 = 0x03;
    constexpr uint8_t kmask2 = 0x0C;
    constexpr uint8_t kmask3 = 0x30;
    constexpr uint8_t kmask4 = 0xC0;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    // Byte offsets: nb0x / nb1x are byte strides added to the char base pointers.
    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
    device const float     * yy = (device const float      *) (src1 + offset1);

    float sumf[nr0] = { 0.f };

    // The lane's 16 src1 values: 4 consecutive values from each of 4 groups
    // spaced 32 elements apart.
    float yl[16];

    const short tid = tiisg/2;
    const short ix  = tiisg%2;
    const short ip  = tid/8;         // 0 or 1
    const short il  = tid%8;
    const short l0  = 4*il;
    const short is  = 8*ip + l0/16;  // first int8 scale used by this lane

    const short y_offset   = 128*ip + l0;
    const short q_offset_l =  64*ip + l0;
    const short q_offset_h =  32*ip + l0;

    for (int i = ix; i < nb; i += 2) {
        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;
        device const half    * dh = &x[i].d;

        device const float * y = yy + i * QK_K + y_offset;

        // src1 values are loaded once per super-block and reused for all nr0 rows.
        for (short l = 0; l < 4; ++l) {
            yl[4*l + 0] = y[l +  0];
            yl[4*l + 1] = y[l + 32];
            yl[4*l + 2] = y[l + 64];
            yl[4*l + 3] = y[l + 96];
        }

        for (short row = 0; row < nr0; ++row) {
            float4 sums = {0.f, 0.f, 0.f, 0.f};

            // Rebuild each 6-bit quant: low 4 bits from a ql nibble, high 2 bits
            // from qh, then re-centre by subtracting 32.
            FOR_UNROLL (short l = 0; l < 4; ++l) {
                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
            }

            // sc[0], sc[2], sc[4], sc[6] are the scales of the four groups that lie
            // 32 elements apart (one scale every 16 elements).
            sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

            // Move to the same super-block of the next row. nb01 is a byte stride,
            // hence nb01/2 for the half pointer.
            q1 += args.nb01;
            q2 += args.nb01;
            qh += args.nb01;
            sc += args.nb01;
            dh += args.nb01/2;
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // The bound on first_row + row is uniform within the simdgroup, so every lane
    // takes part in each simd_sum.
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
7278
[[host_name("kernel_mul_mv_q6_K_f32")]]
kernel void kernel_mul_mv_q6_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the q6_K x f32 mat-vec. The impl does not use threadgroup
    // memory, so a null shmem pointer is forwarded.
    constexpr int nr0 = N_R0_Q6_K;
    threadgroup char * no_shmem = nullptr;

    kernel_mul_mv_q6_K_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, no_shmem, tgpig, tiisg, sgitg);
}
7291
7292// ======================= "True" 2-bit
7293
// Matrix-vector product of an iq2_xxs-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result, scaled by 0.25f.
//
// shmem must provide at least 256*sizeof(uint64_t) + 128 bytes: a copy of
// iq2xxs_grid followed by a copy of ksigns_iq2xs. Every thread of the threadgroup
// takes part in that copy and in the barrier. The threadgroup is assumed to consist
// of NSG simdgroups of 32 lanes (TODO confirm against the host-side dispatch).
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
    device const float         * y = (device const float         *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid (256 x uint64) and sign (128 x uint8) tables in threadgroup memory.
    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
    {
        // Stride the copy over all 32*NSG threads. A fixed per-thread count
        // (4 grid entries, 2 sign bytes) covers exactly 256/128 entries only when
        // NSG == 2. With fewer simdgroups, entries stay uninitialized. With more,
        // the copy reads past the constant tables and writes past the shmem layout.
        const int ith = 32*sgitg + tiisg;
        const int nth = 32*NSG;
        for (int i = ith; i < 256; i += nth) svalues[i] = iq2xxs_grid[i];
        for (int i = ith; i < 128; i += nth) ssigns[i]  = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xxs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            // Bytes 0..3 of the sub-block: one grid index per 8 values.
            device const uint8_t * aux8 = (device const uint8_t *)q2;
            // Words 2..3: four 7-bit sign-pattern indices (bits 0..27) and a 4-bit
            // scale (bits 28..31). The high word is widened before shifting so the
            // shift is done on an unsigned value, not on a promoted signed int.
            const uint32_t aux32 = q2[2] | ((uint32_t)q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float sum = 0;
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                for (short j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * sum;

            // Same sub-block of the next row; nb01 is a byte stride.
            dh += args.nb01/2;
            q2 += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // The 0.25 factor of the block scale is applied once, after the reduction.
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}
7388
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq2_xxs x f32 mat-vec. The threadgroup buffer bound at
    // index 0 is handed to the impl, which stages its lookup tables there.
    constexpr int nr0 = N_R0_IQ2_XXS;

    kernel_mul_mv_iq2_xxs_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
7401
// Matrix-vector product of an iq2_xs-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result, scaled by 0.25f.
//
// shmem must provide at least 512*sizeof(uint64_t) + 128 bytes: a copy of
// iq2xs_grid followed by a copy of ksigns_iq2xs. Every thread of the threadgroup
// takes part in that copy and in the barrier. The threadgroup is assumed to consist
// of NSG simdgroups of 32 lanes (TODO confirm against the host-side dispatch).
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid (512 x uint64) and sign (128 x uint8) tables in threadgroup memory.
    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);
    {
        // Stride the copy over all 32*NSG threads. A fixed per-thread count
        // (8 grid entries, 2 sign bytes) covers exactly 512/128 entries only when
        // NSG == 2. With fewer simdgroups, entries stay uninitialized. With more,
        // the copy reads past the constant tables and writes past the shmem layout.
        const int ith = 32*sgitg + tiisg;
        const int nth = 32*NSG;
        for (int i = ith; i < 512; i += nth) svalues[i] = iq2xs_grid[i];
        for (int i = ith; i < 128; i += nth) ssigns[i]  = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const uint8_t  * sc = xr->scales + ib;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            // Low nibble scales the first 16 values, high nibble the last 16.
            const uint8_t ls1 = sc[0] & 0xf;
            const uint8_t ls2 = sc[0] >>  4;
            const float d1 = db * (0.5f + ls1);
            const float d2 = db * (0.5f + ls2);

            // Each 16-bit entry: low 9 bits = grid index, high 7 bits = sign index.
            float sum1 = 0, sum2 = 0;
            for (short l = 0; l < 2; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
                for (short j = 0; j < 8; ++j) {
                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            for (short l = 2; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
                for (short j = 0; j < 8; ++j) {
                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d1 * sum1 + d2 * sum2;

            // Same sub-block of the next row; nb01 is a byte stride.
            dh += args.nb01/2;
            q2 += args.nb01/2;
            sc += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // The 0.25 factor of the block scale is applied once, after the reduction.
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}
7506
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq2_xs x f32 mat-vec. The threadgroup buffer bound at
    // index 0 is handed to the impl, which stages its lookup tables there.
    constexpr int nr0 = N_R0_IQ2_XS;

    kernel_mul_mv_iq2_xs_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
7520
// Matrix-vector product of an iq3_xxs-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result, scaled by 0.5f.
//
// shmem must provide at least 256*sizeof(uint32_t) + 128 bytes: a copy of
// iq3xxs_grid followed by a copy of ksigns_iq2xs. Every thread of the threadgroup
// takes part in that copy and in the barrier. The threadgroup is assumed to consist
// of NSG simdgroups of 32 lanes (TODO confirm against the host-side dispatch).
template<int nr0, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
    device const float         * y = (device const float         *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid (256 x uint32) and sign (128 x uint8) tables in threadgroup memory.
    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
    {
        // Stride the copy over all 32*NSG threads. A fixed per-thread count
        // (4 grid entries, 2 sign bytes) covers exactly 256/128 entries only when
        // NSG == 2. With fewer simdgroups, entries stay uninitialized. With more,
        // the copy reads past the constant tables and writes past the shmem layout.
        const int ith = 32*sgitg + tiisg;
        const int nth = 32*NSG;
        for (int i = ith; i < 256; i += nth) svalues[i] = iq3xxs_grid[i];
        for (int i = ith; i < 128; i += nth) ssigns[i]  = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq3_xxs * xr = x + ibl;
        // 8 grid indices per sub-block (4 values each) ...
        device const uint8_t  * q3 = xr->qs + 8 * ib;
        // ... and, after the first QK_K/4 bytes of qs, two 16-bit words per
        // sub-block holding the sign indices and the scale.
        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            // Four 7-bit sign-pattern indices (bits 0..27) and a 4-bit scale
            // (bits 28..31). The high word is widened before shifting so the shift
            // is done on an unsigned value, not on a promoted signed int.
            const uint32_t aux32 = gas[0] | ((uint32_t)gas[1] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float2 sum = {0};
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            // Same sub-block of the next row; nb01 is a byte stride.
            dh  += args.nb01/2;
            q3  += args.nb01;
            gas += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // The 0.5 factor of the block scale is applied once, after the reduction.
            dst_f32[first_row + row] = sum_all * 0.5f;
        }
    }
}
7618
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq3_xxs x f32 mat-vec. The threadgroup buffer bound at
    // index 0 is handed to the impl, which stages its lookup tables there.
    constexpr int nr0 = N_R0_IQ3_XXS;

    kernel_mul_mv_iq3_xxs_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
7632
// Matrix-vector product of an iq3_s-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result.
//
// shmem must provide at least 512*sizeof(uint32_t) bytes for a copy of iq3s_grid.
// Every thread of the threadgroup takes part in that copy and in the barrier.
// The threadgroup is assumed to consist of NSG simdgroups of 32 lanes
// (TODO confirm against the host-side dispatch).
template<int nr0, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the 512-entry grid in threadgroup memory.
    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
    {
        // Stride the copy over all 32*NSG threads. A fixed 8 entries per thread
        // covers exactly 512 entries only when NSG == 2. With fewer simdgroups,
        // entries stay uninitialized. With more, the copy reads past iq3s_grid and
        // writes past the shmem layout.
        const int ith = 32*sgitg + tiisg;
        const int nth = 32*NSG;
        for (int i = ith; i < 512; i += nth) svalues[i] = iq3s_grid[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq3_s * xr = x + ibl;
        device const uint8_t * qs = xr->qs + 8 * ib;
        device const uint8_t * qh = xr->qh + ib;
        // Two consecutive sub-blocks share one scale byte (one nibble each).
        device const uint8_t * sc = xr->scales + (ib/2);
        device const uint8_t * signs = xr->signs + 4 * ib;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));

            float2 sum = {0};
            for (short l = 0; l < 4; ++l) {
                // A qh bit selects the upper half of the grid (entries 256..511).
                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
                for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            // Same sub-block of the next row; nb01 is a byte stride.
            dh    += args.nb01/2;
            qs    += args.nb01;
            qh    += args.nb01;
            sc    += args.nb01;
            signs += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
7730
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq3_s x f32 mat-vec. The threadgroup buffer bound at
    // index 0 is handed to the impl, which stages its grid table there.
    constexpr int nr0 = N_R0_IQ3_S;

    kernel_mul_mv_iq3_s_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
7744
// Matrix-vector product of an iq2_s-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result, scaled by 0.25f.
//
// The grid is read directly from the constant iq2s_grid table, so shmem is not
// used. A threadgroup-staging variant is left commented out below.
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Disabled: staging iq2s_grid in threadgroup memory (the grid is read from
    // constant memory in the loop below instead).
    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
    //{
    //    int nval = 32;
    //    int pos  = (32*sgitg + tiisg)*nval;
    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_s * xr = x + ibl;
        // 4 low index bytes per sub-block; the sign bytes sit QK_K/8 bytes later in qs.
        device const uint8_t * qs = xr->qs + 4 * ib;
        device const uint8_t * qh = xr->qh + ib;
        device const uint8_t * sc = xr->scales + ib;
        device const uint8_t * signs = qs + QK_K/8;
        device const half * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            // Low nibble scales values 0..15 (sum[0]), high nibble values 16..31 (sum[1]).
            const float d1 = db * (0.5f + (sc[0] & 0xf));
            const float d2 = db * (0.5f + (sc[0] >>  4));

            float2 sum = {0};
            for (short l = 0; l < 2; ++l) {
                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                // 10-bit grid index: 8 bits from qs plus 2 bits of qh moved to bits 8..9
                // (qh bits 2l..2l+1 for grid1, bits 4+2l..5+2l for grid2).
                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                for (short j = 0; j < 8; ++j) {
                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                }
            }
            sumf[row] += d1 * sum[0] + d2 * sum[1];

            // Same sub-block of the next row; nb01 is a byte stride.
            dh    += args.nb01/2;
            qs    += args.nb01;
            qh    += args.nb01;
            sc    += args.nb01;
            signs += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // The 0.25 factor of the block scale is applied once, after the reduction.
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}
7843
[[host_name("kernel_mul_mv_iq2_s_f32")]]
kernel void kernel_mul_mv_iq2_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq2_s x f32 mat-vec. The threadgroup buffer bound at
    // index 0 is forwarded unchanged to the impl.
    constexpr int nr0 = N_R0_IQ2_S;

    kernel_mul_mv_iq2_s_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
7857
// Matrix-vector product of an iq1_s-quantized src0 with an f32 src1 vector.
//
// Work split (as computed below):
//   - simdgroup sgitg of threadgroup r0 owns the nr0 rows starting at
//     first_row = (r0*NSG + sgitg)*nr0;
//   - r1 selects the src1 column / dst row and im the flattened (i12, i13) batch;
//     each src0 batch is shared by args.r2 (resp. args.r3) consecutive src1 batches;
//   - each lane handles one 32-value sub-block (ib32) per iteration, then steps
//     ahead by 32 sub-blocks; partial sums are reduced with simd_sum and lane 0
//     writes the result.
//
// The grid is read directly from the constant iq1s_grid_gpu table; shmem is not used.
template<int nr0, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        // sumy lets the per-sub-block offset be applied once instead of per element.
        float sumy = 0;
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
            sumy += yl[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq1_s * xr = x + ibl;
        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint16_t * qh = xr->qh + ib;
        device const half     * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            // 11-bit grid index per 8 values: 8 bits from qs plus 3 bits of qh
            // (bits 0..2, 3..5, 6..8, 9..11) moved to bits 8..10.
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));

            // Each grid byte holds two 4-bit values: low nibble for element j,
            // high nibble for element j+4 of the same group of 8.
            float sum = 0;
            for (short j = 0; j < 4; ++j) {
                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }
            // Adding sumy*(-1 -/+ IQ1S_DELTA) shifts every grid value by the same
            // constant (qh bit 15 picks the sign of the delta); qh bits 12..14 give
            // the odd sub-block scale 2*s + 1.
            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);

            // Same sub-block of the next row; nb01 is a byte stride.
            dh += args.nb01/2;
            qs += args.nb01;
            qh += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
7943
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // Entry point for the iq1_s x f32 mat-vec. The impl reads its grid from
    // constant memory, so no threadgroup buffer is bound and shmem is null.
    constexpr int nr0 = N_R0_IQ1_S;
    threadgroup char * no_shmem = nullptr;

    kernel_mul_mv_iq1_s_f32_impl<nr0, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, no_shmem, tgpig, tiisg, sgitg);
}
7956
// Mat-vec kernel body for an IQ1_M-quantized src0 times an F32 src1 row:
//   dst[im][r1][first_row + row] = dot(src0 row (first_row + row), src1 row r1)
// Each simdgroup computes nr0 consecutive output rows starting at first_row. The 32 lanes split the
// K dimension into 32-value sub-blocks (lane ix handles sub-blocks ix, ix+32, ix+64, ...), and the
// per-lane partial sums are combined with simd_sum at the end.
// shmem is not referenced by this implementation (the kernel wrapper passes nullptr).
template<int nr0, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    // number of simdgroups per threadgroup (sgitg ranges over 0..NSG-1)
    const short NSG = FC_mul_mv_nsg;

    // number of QK_K super-blocks per row
    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    // first src0 row handled by this simdgroup
    const int first_row = (r0 * NSG + sgitg) * nr0;

    // im enumerates the (i12, i13) batch of src1; src0 is broadcast over it via the r2/r3 ratios
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    // yl: this lane's 32 src1 values for the current sub-block; sumf: per-row partials of this lane
    float yl[32];
    float sumf[nr0]={0.f};

    // number of 32-value sub-blocks per row (QK_K/32 per super-block)
    const int nb32 = nb * (QK_K / 32);

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    iq1m_scale_t scale;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        // load the sub-block's src1 values and keep the sum of each group of 8 (used by the delta term)
        float4 sumy = {0.f};
        for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
        }

        // ibl: super-block index; ib: sub-block index within that super-block
        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq1_m * xr = x + ibl;
        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint8_t  * qh = xr->qh + 2 * ib;
        device const uint16_t * sc = (device const uint16_t *)xr->scales;

        for (short row = 0; row < nr0; row++) {
            // fp16 super-block scale, assembled from the top nibble of each of the 4 scale words
            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

            // 11-bit grid indices: 8 bits from qs[k] plus 3 bits taken from a qh nibble (mask 0x700)
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

            // each grid byte packs two 4-bit values: the low nibble pairs with yl[j+..], the high
            // nibble with yl[j+4+..]; sum[0] covers values 0..15, sum[1] covers values 16..31
            float2 sum = {0.f};
            for (short j = 0; j < 4; ++j) {
                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }
            // qh bits 0x08 / 0x80 select the sign of the IQ1M_DELTA offset for each group of 8 values
            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

            // apply the two 3-bit sub-block scales s (as 2*s + 1), then the super-block scale
            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));

            // step to the same sub-block of the next src0 row, nb01 bytes ahead (sc counts uint16_t)
            sc += args.nb01/2;
            qs += args.nb01;
            qh += args.nb01;
        }

        // this lane's next sub-block is 32 sub-blocks (32*32 floats) further along src1
        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce the lane partials across the simdgroup; lane 0 stores each row that lies inside dst
    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
8052
[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // IQ1_M x F32 mat-vec entry point: the implementation never touches shmem,
    // so no threadgroup buffer is bound and a null pointer is passed through.
    constexpr int n_rows = N_R0_IQ1_M;

    threadgroup char * no_shmem = nullptr;

    kernel_mul_mv_iq1_m_f32_impl<n_rows, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, no_shmem, tgpig, tiisg, sgitg);
}
8065
// Mat-vec kernel body for an IQ4_NL-quantized src0 times an F32 src1 row.
// Each simdgroup computes NR0 consecutive output rows starting at first_row. Lanes work in pairs:
// ix = tiisg/2 picks one of the 16 blocks processed per iteration, and it = tiisg%2 picks which
// 8 quant bytes of that block the lane decodes. Lane partials are combined with simd_sum at the end.
// shmem must provide at least 32 floats: the 16-entry kvalues_iq4nl_f table is staged there twice.
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    // number of simdgroups per threadgroup (sgitg ranges over 0..NSG-1)
    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    // first src0 row handled by this simdgroup
    const int first_row = (r0 * NSG + sgitg) * NR0;

    // im enumerates the (i12, i13) batch of src1; src0 is broadcast over it via the r2/r3 ratios
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    // nb: blocks per row; ns01: blocks per src0 row stride (can exceed nb for permuted src0)
    const int nb   = args.ne00/QK4_NL;
    const int ns01 = args.nb01/args.nb00;

    const short ix = tiisg/2;  // 0...15
    const short it = tiisg%2;  // 0 or 1

    // stage the dequantization table in threadgroup memory; every simdgroup writes identical values
    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0]={0.f};

    // src1 values matching this lane's quant bytes start it*8 floats into block ix
    device const float * yb = y + ix*QK4_NL + it*8;

    // aux32 holds 8 unpacked 4-bit indices, which are read back byte-wise through q8
    uint32_t aux32[2];
    thread const uint8_t * q8 = (thread const uint8_t *)aux32;

    float4 qf1, qf2;

    // `ib < nb` alone bounds the loop; the redundant `ib < ns01` is kept because it was observed
    // to make the kernel faster [TAG_MUL_MV_WEIRD]
    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
        // yl[0], yl[2] = yb[0..7] (paired with low nibbles); yl[1], yl[3] = yb[16..23] (high nibbles)
        device const float4 * y4 = (device const float4 *)yb;
        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        for (short row = 0; row < NR0; row++) {
            device const block_iq4_nl & xb = x[row*ns01 + ib];
            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

            float4 acc1 = {0.f}, acc2 = {0.f};

            // quant bytes 0..3 (two uint16 loads), split into low nibbles (aux32[0]) and high nibbles (aux32[1])
            aux32[0] = q4[0] | (q4[1] << 16);
            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
            aux32[0] &= 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
            acc1 += yl[0] * qf1;
            acc2 += yl[1] * qf2;

            // quant bytes 4..7, same split
            aux32[0] = q4[2] | (q4[3] << 16);
            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
            aux32[0] &= 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
            acc1 += yl[2] * qf1;
            acc2 += yl[3] * qf2;

            acc1 += acc2;

            // scale the horizontal sum by the block scale d
            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        // 16 lane pairs -> advance 16 blocks
        yb += 16 * QK4_NL;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce the lane partials across the simdgroup; lane 0 stores each row that lies inside dst
    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
8161
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // IQ4_NL x F32 mat-vec entry point. The implementation stages its 16-value lookup table
    // in the threadgroup(0) buffer (it writes 32 floats there), so shmem is forwarded as-is.
    constexpr int n_rows = N_R0_IQ4_NL;

    kernel_mul_mv_iq4_nl_f32_impl<n_rows, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
8175
// Mat-vec kernel body for an IQ4_XS-quantized src0 times an F32 src1 row.
// Each simdgroup computes NR0 consecutive output rows starting at first_row. Lane roles:
//   ix = tiisg/16     -> super-block parity (the lane visits super-blocks ix, ix+2, ix+4, ...)
//   ib = (tiisg%16)/2 -> 32-value sub-block within the super-block
//   il = tiisg%2      -> which 8 quant bytes (16 values) of that sub-block the lane decodes
// Lane partials are combined with simd_sum at the end.
// shmem must provide at least 32 floats: the 16-entry kvalues_iq4nl_f table is staged there twice.
template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    // number of simdgroups per threadgroup (sgitg ranges over 0..NSG-1)
    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
    // first src0 row handled by this simdgroup
    const int first_row = (r0 * NSG + sgitg) * NR0;

    // im enumerates the (i12, i13) batch of src1; src0 is broadcast over it via the r2/r3 ratios
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    // nb: super-blocks per row; ns01: super-blocks per src0 row stride (can exceed nb for permuted src0)
    const int nb   = args.ne00/QK_K;
    const int ns01 = args.nb01/args.nb00;

    const short ix = tiisg/16;  // 0 or 1
    const short it = tiisg%16;  // 0...15
    const short ib = it/2;
    const short il = it%2;

    // stage the dequantization table in threadgroup memory; every simdgroup writes identical values
    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0]={0.f};

    // src1 values matching this lane's quant bytes: super-block ix, sub-block ib, half il
    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

    // aux32 holds 8 unpacked 4-bit indices, which are read back byte-wise through q8
    uint32_t aux32[2];
    thread const uint8_t * q8 = (thread const uint8_t *)aux32;

    float4 qf1, qf2;

    // `ibl < nb` alone bounds the loop; the redundant `ibl < ns01` is kept because it was observed
    // to make the kernel faster [TAG_MUL_MV_WEIRD]
    for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
        // yl[0], yl[2] = yb[0..7] (paired with low nibbles); yl[1], yl[3] = yb[16..23] (high nibbles)
        device const float4 * y4 = (device const float4 *)yb;
        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        for (short row = 0; row < NR0; ++row) {
            device const block_iq4_xs & xb = x[row*ns01 + ibl];
            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

            float4 acc1 = {0.f}, acc2 = {0.f};

            // quant bytes 0..3: low nibbles -> aux32[0], high nibbles -> aux32[1]
            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;
            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
            acc1 += yl[0] * qf1;
            acc2 += yl[1] * qf2;

            // quant bytes 4..7, same split
            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;
            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
            acc1 += yl[2] * qf1;
            acc2 += yl[3] * qf2;

            acc1 += acc2;

            // 6-bit sub-block scale: low 4 bits from scales_l, high 2 bits from scales_h, offset by -32
            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        // the two lane halves interleave super-blocks -> advance 2 super-blocks
        yb += 2 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce the lane partials across the simdgroup; lane 0 stores each row that lies inside dst
    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
8271
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // IQ4_XS x F32 mat-vec entry point. The implementation stages its 16-value lookup table
    // in the threadgroup(0) buffer (it writes 32 floats there), so shmem is forwarded as-is.
    constexpr int n_rows = N_R0_IQ4_XS;

    kernel_mul_mv_iq4_xs_f32_impl<n_rows, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
8285
// Mat-vec kernel body for an MXFP4-quantized src0 times an F32 src1 row.
// Each simdgroup computes NR0 consecutive output rows starting at first_row. Lanes work in pairs:
// ix = tiisg/2 picks one of the 16 blocks processed per iteration, and it = tiisg%2 picks which
// 8 quant bytes of that block the lane decodes. The 16 FP4 values (kvalues_mxfp4_f) are staged in
// shmem (at least 32 floats required; the table is written twice), and each block's sum is scaled
// by e8m0_to_fp32(xb.e). Lane partials are combined with simd_sum at the end.
template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    // number of simdgroups per threadgroup (sgitg ranges over 0..NSG-1)
    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    // first src0 row handled by this simdgroup
    const int first_row = (r0 * NSG + sgitg) * NR0;

    // im enumerates the (i12, i13) batch of src1; src0 is broadcast over it via the r2/r3 ratios
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    const int nb   = args.ne00/QK_MXFP4;
    const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors

    const short ix = tiisg/2;  // 0...15
    const short it = tiisg%2;  // 0 or 1

    // stage the FP4 value table in threadgroup memory; every simdgroup writes identical values
    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0]={0.f};

    // src1 values matching this lane's quant bytes start it*8 floats into block ix
    device const float * yb = y + ix*QK_MXFP4 + it*8;

    // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
    //       no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
        device const float4 * y4 = (device const float4 *) yb;

        // yl[0], yl[2] = yb[0..7] (paired with low nibbles); yl[1], yl[3] = yb[16..23] (high nibbles)
        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        FOR_UNROLL (short row = 0; row < NR0; row++) {
            device const block_mxfp4 & xb = x[row*ns01 + ib];
            device const uint8_t     * q2 = (device const uint8_t *)(xb.qs + 8*it);

            // bytes 0..3 -> acc1 (low nibbles) / acc2 (high nibbles); bytes 4..7 -> acc3 / acc4
            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);

            acc1 = (acc1 + acc3) + (acc2 + acc4);

            // scale the horizontal sum by the block's E8M0 scale
            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
        }

        // 16 lane pairs -> advance 16 blocks
        yb += 16 * QK_MXFP4;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    // reduce the lane partials across the simdgroup; lane 0 stores each row that lies inside dst
    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
8365
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // MXFP4 x F32 mat-vec entry point. The implementation stages its 16-value FP4 table
    // in the threadgroup(0) buffer (it writes 32 floats there), so shmem is forwarded as-is.
    constexpr int n_rows = N_R0_MXFP4;

    kernel_mul_mv_mxfp4_f32_impl<n_rows, constant ggml_metal_kargs_mul_mv &>(
            args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
8379
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device       void * dst,
        uint3               tgpig[[threadgroup_position_in_grid]],
        ushort              tiitg[[thread_index_in_threadgroup]],
        ushort3             ntg  [[threads_per_threadgroup]]) {
    // Gather for quantized src0: copies row r of src0 (r is read from src1) into dst as F32,
    // dequantizing one 4x4 tile (16 values) per thread. tgpig.x packs both the chunk of the
    // output row (chunk) and the position i10 inside src1.
    const int32_t chunk = tgpig.x/args.ne10;
    const int32_t i10   = tgpig.x%args.ne10;
    const int32_t i11   = tgpig.y;
    const int32_t i12   = tgpig.z;

    // tile handled by this thread; threads past the end of the row have nothing to do
    const int ind = chunk*ntg.x + tiitg;
    if (ind >= args.ne00t) {
        return;
    }

    device const char * idx_bytes = (device const char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10;
    const int32_t r = *((device const int32_t *) idx_bytes);

    // src0 dims 2/3 are indexed directly by src1 dims 1/2
    device const block_q  * psrc = (device const block_q  *) ((device const char *) src0 + i12*args.nb03 + i11*args.nb02 + r*args.nb01);
    device       float4x4 * pdst = (device       float4x4 *) ((device       char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // each block_q holds nl tiles: pick block ind/nl and tile ind%nl within it
    float4x4 tile;
    dequantize_func(psrc + ind/nl, ind%nl, tile);
    pdst[ind] = tile;
}
8410
template<typename T0, typename T>
kernel void kernel_get_rows_f(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device       void * dst,
        uint3               tgpig[[threadgroup_position_in_grid]],
        ushort              tiitg[[thread_index_in_threadgroup]],
        ushort3             ntg [[threads_per_threadgroup]]) {
    // Gather for non-quantized src0: copies row r of src0 (r is read from src1) into dst,
    // converting T0 -> T, one element per thread. tgpig.x packs both the chunk of the row
    // (chunk) and the position i10 inside src1.
    const int32_t chunk = tgpig.x/args.ne10;
    const int32_t i10   = tgpig.x%args.ne10;
    const int32_t i11   = tgpig.y;
    const int32_t i12   = tgpig.z;

    // element handled by this thread; threads past the end of the row have nothing to do
    const int ind = chunk*ntg.x + tiitg;
    if (ind >= args.ne00t) {
        return;
    }

    device const char * idx_bytes = (device const char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10;
    const int32_t r = *((device const int32_t *) idx_bytes);

    // src0 dims 2/3 are indexed directly by src1 dims 1/2
    device const T0 * psrc = (device const T0 *) ((device const char *) src0 + i12*args.nb03 + i11*args.nb02 + r*args.nb01);
    device       T  * pdst = (device       T  *) ((device       char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    pdst[ind] = psrc[ind];
}
8439
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
        constant ggml_metal_kargs_set_rows & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    // Scatter with quantization: F32 row src_i1 of src0 is quantized into dst row dst_i1, where
    // dst_i1 is read from src1. Each group of tptg.x threads handles one src0 row (tptg.y rows per
    // threadgroup), and each quantized block consumes 32 consecutive floats.
    const int32_t src_i3 = tgpig.z;
    const int32_t src_i2 = tgpig.y;
    const int32_t src_i1 = tgpig.x*tptg.y + tiitg/tptg.x;

    // rows past the end of src0 have nothing to do
    if (src_i1 >= args.ne01) {
        return;
    }

    // src1 (the row-index tensor) is broadcast over src0's dims 2 and 3
    const int32_t idx_i2 = src_i3%args.ne12;
    const int32_t idx_i1 = src_i2%args.ne11;

    device const char * idx_bytes = (device const char *) src1 + src_i1*args.nb10 + idx_i1*args.nb11 + idx_i2*args.nb12;
    const TI dst_i1 = *((device const TI *) idx_bytes);

    device const float   * src_row = (device const float   *) ((device const char *) src0 + src_i1*args.nb01 + src_i2*args.nb02 + src_i3*args.nb03);
    device       block_q * dst_row = (device       block_q *) ((device       char *) dst  + dst_i1*args.nb1  + src_i2*args.nb2  + src_i3*args.nb3);

    const int lane   = tiitg%tptg.x;
    const int stride = tptg.x;
    for (int blk = lane; blk < args.nk0; blk += stride) {
        quantize_func(src_row + 32*blk, dst_row[blk]);
    }
}
8470
template<typename T, typename TI>
kernel void kernel_set_rows_f(
        constant ggml_metal_kargs_set_rows & args,
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        uint3                tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint3                tptg [[threads_per_threadgroup]]) {
    // Scatter: F32 row src_i1 of src0 is written (converted to T) into dst row dst_i1, where
    // dst_i1 is read from src1. Each group of tptg.x threads handles one src0 row (tptg.y rows
    // per threadgroup).
    const int32_t src_i3 = tgpig.z;
    const int32_t src_i2 = tgpig.y;
    const int32_t src_i1 = tgpig.x*tptg.y + tiitg/tptg.x;

    // rows past the end of src0 have nothing to do
    if (src_i1 >= args.ne01) {
        return;
    }

    // src1 (the row-index tensor) is broadcast over src0's dims 2 and 3
    const int32_t idx_i2 = src_i3%args.ne12;
    const int32_t idx_i1 = src_i2%args.ne11;

    device const char * idx_bytes = (device const char *) src1 + src_i1*args.nb10 + idx_i1*args.nb11 + idx_i2*args.nb12;
    const TI dst_i1 = *((device const TI *) idx_bytes);

    device const float * src_row = (device const float *) ((device const char *) src0 + src_i1*args.nb01 + src_i2*args.nb02 + src_i3*args.nb03);
    device       T     * dst_row = (device       T     *) ((device       char *) dst  + dst_i1*args.nb1  + src_i2*args.nb2  + src_i3*args.nb3);

    const int lane   = tiitg%tptg.x;
    const int stride = tptg.x;
    for (int k = lane; k < args.nk0; k += stride) {
        dst_row[k] = (T) src_row[k];
    }
}
8501
// Builds a diagonal matrix for each (i2, i3) slice. Row i1 of dst receives src0_ptr[i1] in
// column i1 and 0.0f in every other column. src0 is read as a single row per slice (no row
// offset is applied to it). One threadgroup handles one dst row (tgpig.x == i1), and its
// threads stride over the ne0 columns.
kernel void kernel_diag_f32(
        constant ggml_metal_kargs_diag & args,
        device   const char * src0,
        device         char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg  [[threads_per_threadgroup]]) {
    const int32_t i3 = tgpig.z;
    const int32_t i2 = tgpig.y;
    const int32_t i1 = tgpig.x;

    // total threads in the threadgroup; tiitg is the flattened index within it
    const int nth = ntg.x*ntg.y*ntg.z;

    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);
    // address the dst row with dst's own row stride (nb1): src0's nb01 describes src0's layout and
    // need not match dst's row pitch (e.g. when src0 is a strided single-row view)
    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // stride by the actual threadgroup size rather than assuming N_SIMDWIDTH threads, so every
    // column is written exactly once regardless of the dispatched threadgroup size
    for (int i0 = tiitg; i0 < args.ne0; i0 += nth) {
        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
    }
}
8521
// Function constants that specialize kernel_mul_mm when the pipeline is created:
//   FC_mul_mm_bc_inp - when true, per-element input loads are checked against ne00 and zero-filled
//                      past the end; when false, src1 is loaded with unchecked vector loads and
//                      src0 always goes through dequantize_func
//   FC_mul_mm_bc_out - when false (or when the 64x32 output tile fits inside ne0 x ne1), results are
//                      written straight to device memory; otherwise a separate output path is taken
//                      (presumably bounds-checked - NOTE(review): confirm in the remainder of kernel_mul_mm)
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
8524
8525// each block_q contains 16*nl weights
8526template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
8527kernel void kernel_mul_mm(
8528        constant ggml_metal_kargs_mul_mm & args,
8529        device const char * src0,
8530        device const char * src1,
8531        device       char * dst,
8532        threadgroup  char * shmem [[threadgroup(0)]],
8533        uint3  tgpig[[threadgroup_position_in_grid]],
8534        ushort tiitg[[thread_index_in_threadgroup]],
8535        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8536
8537    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8538    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8539
8540    threadgroup float * sc = (threadgroup float *)(shmem);
8541
8542    constexpr int NR0 = 64;
8543    constexpr int NR1 = 32;
8544
8545    constexpr int NK  = 32;
8546    constexpr int NL0 = NK/16;
8547    constexpr int NL1 = NK/8;
8548
8549    const int im = tgpig.z;
8550    const int r0 = tgpig.y*NR0;
8551    const int r1 = tgpig.x*NR1;
8552
8553    // if this block is of 64x32 shape or smaller
8554    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
8555    const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
8556
8557    // a thread shouldn't load data outside of the matrix
8558    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
8559    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
8560
8561    const short il0 = (tiitg % NL0);
8562
8563    short il = il0;
8564
8565    const int i12 = im%args.ne12;
8566    const int i13 = im/args.ne12;
8567
8568    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8569    const short    offset1 = il0/nl;
8570
8571    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
8572
8573    const short iy = 8*(tiitg % NL1);
8574
8575    device const T1 * y = (device const T1 *)(src1
8576        + args.nb13*i13
8577        + args.nb12*i12
8578        + args.nb11*(r1 + lr1)
8579        + args.nb10*iy);
8580
8581#ifndef GGML_METAL_HAS_TENSOR
8582    S0_8x8 ma[4];
8583    S1_8x8 mb[2];
8584
8585    simdgroup_float8x8 mc[8];
8586
8587    for (short i = 0; i < 8; i++){
8588        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8589    }
8590#else
8591    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));
8592    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8593
8594    mpp::tensor_ops::matmul2d<
8595        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8596        execution_simdgroups<4>> mm;
8597
8598    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8599#endif
8600
8601    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8602#ifndef GGML_METAL_HAS_TENSOR
8603        // load data and store to threadgroup memory
8604        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8605            threadgroup_barrier(mem_flags::mem_threadgroup);
8606
8607            // no need for dequantization
8608            for (short i = 0; i < 16; i++) {
8609                const short sx = 2*il0 + i/8;
8610                const short sy = (tiitg/NL0)/8;
8611
8612              //const short lx = i%8;
8613              //const short ly = (tiitg/NL0)%8;
8614                const short lx = (tiitg/NL0)%8;
8615                const short ly = i%8;
8616
8617                const short ib = 8*sx + sy;
8618
8619                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8620            }
8621        } else {
8622            S0_4x4 temp_a;
8623            dequantize_func(x, il, temp_a);
8624
8625            threadgroup_barrier(mem_flags::mem_threadgroup);
8626
8627            FOR_UNROLL (short i = 0; i < 16; i++) {
8628                const short sx = 2*il0 + i/8;
8629                const short sy = (tiitg/NL0)/8;
8630
8631              //const short lx = i%8;
8632              //const short ly = (tiitg/NL0)%8;
8633                const short lx = (tiitg/NL0)%8;
8634                const short ly = i%8;
8635
8636                const short ib = 8*sx + sy;
8637
8638                // NOTE: this is massively slower.. WTF?
8639                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
8640
8641                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
8642            }
8643        }
8644
8645        if (FC_mul_mm_bc_inp) {
8646            for (short i = 0; i < 8; ++i) {
8647                const short sx = (tiitg%NL1);
8648                const short sy = (tiitg/NL1)/8;
8649
8650                const short lx = i;
8651                const short ly = (tiitg/NL1)%8;
8652              //const short lx = (tiitg/NL1)%8;
8653              //const short ly = i;
8654
8655                const short ib = 4*sx + sy;
8656
8657                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8658            }
8659        } else {
8660            const short sx = (tiitg%NL1);
8661            const short sy = (tiitg/NL1)/8;
8662
8663            const short dx = sx;
8664            const short dy = sy;
8665
8666            const short ly = (tiitg/NL1)%8;
8667
8668            const short ib = 4*sx + sy;
8669
8670            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
8671        }
8672#else
8673        // load data and store to threadgroup memory
8674        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8675            threadgroup_barrier(mem_flags::mem_threadgroup);
8676
8677            // no need for dequantization
8678            for (short i = 0; i < 16; i++) {
8679                const short sx = 2*il0 + i/8;
8680                const short sy = (tiitg/NL0)/8;
8681
8682                const short lx = i%8;
8683                const short ly = (tiitg/NL0)%8;
8684                //const short lx = (tiitg/NL0)%8;
8685                //const short ly = i%8;
8686
8687                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8688            }
8689        } else {
8690            S0_4x4 temp_a;
8691            dequantize_func(x, il, temp_a);
8692
8693            threadgroup_barrier(mem_flags::mem_threadgroup);
8694
8695            FOR_UNROLL (short i = 0; i < 16; i++) {
8696                const short sx = 2*il0 + i/8;
8697                const short sy = (tiitg/NL0)/8;
8698
8699                const short lx = i%8;
8700                const short ly = (tiitg/NL0)%8;
8701                //const short lx = (tiitg/NL0)%8;
8702                //const short ly = i%8;
8703
8704                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
8705            }
8706        }
8707
8708        if (FC_mul_mm_bc_inp) {
8709            for (short i = 0; i < 8; ++i) {
8710                const short sx = (tiitg%NL1);
8711                const short sy = (tiitg/NL1)/8;
8712
8713                const short lx = i;
8714                const short ly = (tiitg/NL1)%8;
8715                //const short lx = (tiitg/NL1)%8;
8716                //const short ly = i;
8717
8718                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8719            }
8720        } else {
8721            const short sx = (tiitg%NL1);
8722            const short sy = (tiitg/NL1)/8;
8723
8724            //const short lx = i;
8725            const short ly = (tiitg/NL1)%8;
8726            //const short lx = (tiitg/NL1)%8;
8727            //const short ly = i;
8728
8729            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
8730        }
8731#endif
8732
8733        il = (il + 2 < nl) ? il + 2 : il % 2;
8734        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
8735
8736        y += NK;
8737
8738        threadgroup_barrier(mem_flags::mem_threadgroup);
8739
8740#ifndef GGML_METAL_HAS_TENSOR
8741        // load matrices from threadgroup memory and conduct outer products
8742        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
8743        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
8744
8745        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
8746            simdgroup_barrier(mem_flags::mem_none);
8747
8748            FOR_UNROLL (short i = 0; i < 4; i++) {
8749                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8750            }
8751
8752            simdgroup_barrier(mem_flags::mem_none);
8753
8754            FOR_UNROLL (short i = 0; i < 2; i++) {
8755                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8756            }
8757
8758            simdgroup_barrier(mem_flags::mem_none);
8759
8760            FOR_UNROLL (short i = 0; i < 8; i++){
8761                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8762            }
8763
8764            lsma += 8*64;
8765            lsmb += 4*64;
8766        }
8767#else
8768        auto sA = tA.slice(0, 0);
8769        auto sB = tB.slice(0, 0);
8770
8771        mm.run(sB, sA, cT);
8772#endif
8773    }
8774
8775    if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
8776        // if no bounds checks on the output are needed, we can directly write to device memory
8777#ifdef GGML_METAL_HAS_TENSOR
8778        device float * C = (device float *) dst +
8779            r0 + \
8780            r1 * args.ne0 + im*args.ne1*args.ne0;
8781
8782        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
8783        cT.store(tC);
8784#else
8785        device float * C = (device float *) dst +
8786            (r0 + 32*(sgitg &  1)) + \
8787            (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
8788
8789        for (short i = 0; i < 8; i++) {
8790            simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
8791        }
8792#endif
8793    } else {
8794        // block is smaller than 64x32, we should avoid writing data outside of the matrix
8795        threadgroup_barrier(mem_flags::mem_threadgroup);
8796
8797        threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
8798
8799#ifdef GGML_METAL_HAS_TENSOR
8800        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
8801        cT.store(tC);
8802#else
8803        for (short i = 0; i < 8; i++) {
8804            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
8805        }
8806#endif
8807
8808        threadgroup_barrier(mem_flags::mem_threadgroup);
8809
8810        if (sgitg == 0) {
8811            for (int j = tiitg; j < nr1; j += NR1) {
8812                device float  * D  = (device float  *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
8813                device float4 * D4 = (device float4 *) D;
8814
8815                threadgroup float  * C  = temp_str + (j*NR0);
8816                threadgroup float4 * C4 = (threadgroup float4 *) C;
8817
8818                int i = 0;
8819                for (; i < nr0/4; i++) {
8820                    *(D4 + i) = *(C4 + i);
8821                }
8822
8823                i *= 4;
8824                for (; i < nr0; i++) {
8825                    *(D + i) = *(C + i);
8826                }
8827            }
8828        }
8829    }
8830}
8831
// Builds the per-expert routing tables that kernel_mul_mm_id consumes.
//
// Thread `tpitg` owns expert `tpitg`, so the threadgroup size is presumably the number of
// experts (NOTE(review): confirm against the host dispatch). Tokens are handled in batches of
// `ntg`: first every thread stages the ne20 expert ids selected by one token into threadgroup
// memory (narrowed to uint16), then every thread scans the whole staged batch for its own expert.
//
// Outputs:
//   hids - int32, row per expert of length ne21: entry k = token*ne20 + slot, for k < count
//   htpe - uint32 per expert: number of (token, slot) pairs routed to that expert
//
// Threadgroup memory: ntg*ne20 uint16 values.
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
        constant ggml_metal_kargs_mul_mm_id_map0 & args,
        device  const char * src2,
        device        char * htpe,
        device        char * hids,
        threadgroup   char * shmem [[threadgroup(0)]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    const short expert = tpitg;

    device int32_t * ids_out = (device int32_t *) hids + expert*args.ne21;
    threadgroup uint16_t * staged = (threadgroup uint16_t *) shmem;

    uint32_t count = 0;

    for (int base = 0; base < args.ne21; base += ntg) { // n_tokens
        // stage the expert selection of token (base + tpitg), if it exists
        const int tok_load = base + tpitg;
        if (tok_load < args.ne21) {
            device const int32_t * row = (device const int32_t *) (src2 + tok_load*args.nb21);
            threadgroup uint16_t * dst_slot = staged + tpitg*ne20;

            #pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                dst_slot[k] = row[k];
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // scan every staged token of this batch for this thread's expert
        for (short t = 0; t < ntg && base + t < args.ne21; t++) {
            threadgroup const uint16_t * tok_ids = staged + t*ne20;

            // 1-based slot at which `expert` was selected, 0 when absent
            short hit = 0;
            #pragma unroll(ne20)
            for (short k = 0; k < ne20; k++) {
                hit += (tok_ids[k] == expert) ? (k + 1) : 0;
            }

            // branch-free: always write; when hit == 0 the count does not advance, so the
            // next write to the same index replaces this value
            ids_out[count] = (base + t)*ne20 + hit - 1;
            count += (hit > 0) ? 1 : 0;
        }

        // the staging buffer is reused by the next batch
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    device uint32_t * tpe_out = (device uint32_t *) (htpe);
    tpe_out[expert] = count;
}
8885
// Explicit instantiations of kernel_mul_mm_id_map0, one per ne20 (n_expert_used) value.
// The host_name strings are the names the host-side library lookup uses, so they must stay
// stable; an ne20 value not listed here has no map0 kernel.
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;

template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
8896
// Matrix-matrix multiplication for the indirect (per-expert) case.
//
// Each threadgroup computes one NR0 x NR1 (64 x 32) output tile for a single expert:
//   tgpig.z -> expert `im` (selects the src0 matrix via nb02)
//   tgpig.y -> NR0 rows of src0, i.e. output columns r0 .. r0+63
//   tgpig.x -> NR1 entries of the expert's routed list, r1 .. r1+31
//
// Routing inputs, as written by kernel_mul_mm_id_map0:
//   htpe - uint32 per expert: number of routed (token, slot) entries (`neh1`)
//   hids - int32 rows of length ne21 per expert: packed id token*ne20 + slot
//
// Threadgroup layout (inferred from the indexing below): tiitg/NL0 selects one of 64 src0 rows
// (NL0 = 2 threads x 16 values each) and tiitg/NL1 one of 32 src1 rows (NL1 = 4 threads x 8
// values each), i.e. 128 threads = 4 simdgroups. NOTE(review): confirm against the host dispatch.
//
// Threadgroup memory: `sa` (NR0*NK S0 values) at offset 0, `sb` (NR1*NK S1 values) at offset
// 4096; after the main loop the same memory is reused as an NR0 x NR1 float staging tile
// (8192 bytes). NOTE(review): the host allocation must cover both uses — confirm.
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
        constant ggml_metal_kargs_mul_mm_id & args,
        device const char * src0,
        device const char * src1,
        device const char * htpe,
        device const char * hids,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);

    // output staging tile; aliases sa/sb and is only written after the main loop
    threadgroup float * sc = (threadgroup float *)(shmem);

    constexpr int NR0 = 64; // src0 rows (output columns) per tile
    constexpr int NR1 = 32; // routed entries (output rows) per tile

    constexpr int NK  = 32;    // K-slice loaded per iteration
    constexpr int NL0 = NK/16; // threads sharing one src0 row (16 values each)
    constexpr int NL1 = NK/8;  // threads sharing one src1 row (8 values each)

    const int im = tgpig.z; // expert
    const int r0 = tgpig.y*NR0;
    const int r1 = tgpig.x*NR1;

    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);

    // number of (token, slot) entries routed to this expert
    const int32_t neh1 = tpe_u32[im];

    // uniform across the threadgroup, so returning before any barrier is safe
    if (r1 >= neh1) {
        return;
    }

    // if this block is of 64x32 shape or smaller
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (    neh1 - r1 < NR1) ? (    neh1 - r1) : NR1;

    // a thread shouldn't load data outside of the matrix
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31

    const short il0 = (tiitg % NL0);

    short il = il0;

    // packed id of the routed entry whose src1 row this thread loads: token*ne20 + slot
    const int id = ids_i32[im*args.ne21 + r1 + lr1];

    // slot (taken modulo ne11 so src1 can broadcast over the slot dimension) and token;
    // the token index is kept as int because it can exceed the range of a 16-bit short
    const short i11 = (id % args.ne20) % args.ne11;
    const int   i12 = (id / args.ne20);
    const short i13 = 0;

    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
    const short    offset1 = il0/nl;

    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

    const short iy = 8*(tiitg % NL1);

    device const T1 * y = (device const T1 *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*i11
        + args.nb10*iy);

#ifndef GGML_METAL_HAS_TENSOR
    S0_8x8 ma[4];
    S1_8x8 mb[2];

    simdgroup_float8x8 mc[8];

    for (short i = 0; i < 8; i++){
        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }
#else
    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));
    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));

    mpp::tensor_ops::matmul2d<
        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif

    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
        // load data and store to threadgroup memory (8x8 sub-block layout for simdgroup_load)
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

              //const short lx = i%8;
              //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

              //const short lx = i%8;
              //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                // NOTE: the array-index form below was observed to be much slower than the
                // pointer form, so keep the pointer arithmetic
                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];

                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
              //const short lx = (tiitg/NL1)%8;
              //const short ly = i;

                const short ib = 4*sx + sy;

                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            const short ib = 4*sx + sy;

            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#else
        // load data and store to threadgroup memory (row-major NK-wide layout for the tensor ops)
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            //const short lx = i;
            const short ly = (tiitg/NL1)%8;
            //const short lx = (tiitg/NL1)%8;
            //const short ly = i;

            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#endif

        // advance to the next 16-value chunk; move to the next quant block after nl chunks
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;

        y += NK;

        threadgroup_barrier(mem_flags::mem_threadgroup);

#ifndef GGML_METAL_HAS_TENSOR
        // load matrices from threadgroup memory and conduct outer products
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));

        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
            }

            lsma += 8*64;
            lsmb += 4*64;
        }
#else
        auto sA = tA.slice(0, 0);
        auto sB = tB.slice(0, 0);

        mm.run(sB, sA, cT);
#endif
    }

    // Stage the accumulated tile in threadgroup memory (reusing sa/sb) and scatter it row by
    // row: consecutive routed entries map to non-contiguous dst rows, so the tile cannot be
    // stored to device memory directly.
    threadgroup_barrier(mem_flags::mem_threadgroup);

#ifdef GGML_METAL_HAS_TENSOR
    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
    cT.store(tC);
#else
    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;

    for (short i = 0; i < 8; i++) {
        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
    }
#endif

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // simdgroup sgitg writes staged rows j = sgitg, sgitg + 4, ...; lanes stride over columns
    for (short j = sgitg; j < nr1; j += 4) {
        const int id = ids_i32[im*args.ne21 + r1 + j];

        const short ide = id % args.ne20; // slot within the token's expert selection
        const int   idt = id / args.ne20; // token

        // 64-bit row offset: idt*ne1*ne0 can exceed INT32_MAX for large outputs.
        // NOTE(review): D4 assumes 16-byte aligned rows (r0 is a multiple of 64; presumably the
        // host guarantees ne0 % 4 == 0) — confirm.
        device float  * D  = (device float  *) dst + r0 + ide*args.ne0 + (uint64_t) idt*args.ne1*args.ne0;
        device float4 * D4 = (device float4 *) D;

        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;
        threadgroup float4 * C4 = (threadgroup float4 *) C;

        int i = tiisg;
        for (; i < nr0/4; i += 32) {
            *(D4 + i) = *(C4 + i);
        }

        i = (4*(nr0/4)) + tiisg;
        for (; i < nr0; i += 32) {
            *(D + i) = *(C + i);
        }
    }
}
9199
// Number of 16-value chunks (the `nl` template argument of the get_rows / mul_mm kernels) that
// one quant block is dequantized in; used below for the K-quant and i-quant super-block formats.
// NOTE(review): presumably QK_K/16 — confirm against QK_K in ggml-common.h.
#define QK_NL 16
9201
9202//
9203// get rows
9204//
9205
// get_rows instantiations for non-quantized source types. The second template argument is
// presumably the destination element type (float for f32/f16/bf16, int32_t for i32) — confirm.
// bf16 is only available when GGML_METAL_HAS_BF16 is defined.
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;

template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half,  float>;
template [[host_name("kernel_get_rows_i32")]]  kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
9214
// get_rows instantiations for quantized source types: <block type, nl, dequantize function>.
// nl is the number of 16-value chunks per block (2 for the 32-value block formats, QK_NL for
// the super-block formats), matching the values used by the mul_mm instantiations.
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;

template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]]   kernel get_rows_q_t kernel_get_rows_q<block_mxfp4,   2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
9237
9238//
9239// set rows
9240//
9241
// set_rows instantiations for non-quantized destinations: <element type, row-index type>.
// The _i64 / _i32 suffix of each host name selects the int64_t / int32_t index type.
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;

template [[host_name("kernel_set_rows_f32_i64")]]  kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]]  kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]]  kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]]  kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
9252
// set_rows instantiations for quantized destinations: <row-index type, block type, quantize
// function>. The _i64 / _i32 suffix selects the index type. NOTE(review): "q32" presumably
// refers to the 32-value block formats listed here — confirm against kernel_set_rows_q32.
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;

template [[host_name("kernel_set_rows_q8_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0,   quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0,   quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0,   quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0,   quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1,   quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1,   quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0,   quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0,   quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1,   quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1,   quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
9267
9268//
9269// matrix-matrix multiplication
9270//
9271
// kernel_mul_mm instantiations. Template arguments, in order:
//   S0, S0_4x4, S0_8x8       - threadgroup / simdgroup staging types for src0 (half or bfloat)
//   S1, S1_2x4, S1_8x8       - threadgroup / simdgroup staging types for src1
//   block_q, nl, dequantize  - src0 storage format, 16-value chunks per block, dequantizer
//   T0, T0_4x4               - src0 device element types for the non-dequantizing load path
//   T1, T1_2x4               - src1 device element types (the _f32 / _f16 host-name suffix)
// Note that f32 src0 is staged as half, so the product is computed at half input precision.
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;

template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16,    bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  float, float2x4>;

// same set with half src1 (T1 = half, T1_2x4 = half2x4)
template [[host_name("kernel_mul_mm_f32_f16")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]]     kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  half, half2x4>;
9317template [[host_name("kernel_mul_mm_iq2_s_f16")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  half, half2x4>;
9318template [[host_name("kernel_mul_mm_iq1_s_f16")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  half, half2x4>;
9319template [[host_name("kernel_mul_mm_iq1_m_f16")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  half, half2x4>;
9320template [[host_name("kernel_mul_mm_iq4_nl_f16")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  half, half2x4>;
9321template [[host_name("kernel_mul_mm_iq4_xs_f16")]]  kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  half, half2x4>;
9322
9323//
9324// indirect matrix-matrix multiplication
9325//
9326
// Kernel type used to declare every kernel_mul_mm_id instantiation below. The template arguments
// on this line only serve to spell the type; each explicit instantiation is declared with it.
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;

// _f32 host names: the last two template arguments are `float, float2x4`.
template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16,    bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  float, float2x4>;

// _f16 host names: same source types as above, last two template arguments `half, half2x4`.
// NOTE(review): there is no kernel_mul_mm_id_bf16_f16 counterpart to kernel_mul_mm_id_bf16_f32 -- confirm intended.
template [[host_name("kernel_mul_mm_id_f32_f16")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32,     float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]]     kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16,     half,   half4x4,   half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K,    float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs,  float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m,   float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl,  float,  float4x4,  half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   half,   half2x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs,  float,  float4x4,  half, half2x4>;
9377
9378//
9379// matrix-vector multiplication
9380//
9381
// Function type of a matrix-vector dispatch routine that takes neither threadgroup scratch
// memory nor the simdgroup index.
using kernel_mul_mv_disp_t = void(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig,
        ushort tiisg);

// Function type of a matrix-vector dispatch routine that additionally receives a
// threadgroup scratch buffer and the simdgroup index within the threadgroup.
using kernel_mul_mv2_disp_t = void(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg);
9399
// Adapter for the short dispatch signature: accepts the full argument list used by
// kernel_mul_mv_id and forwards only what the routine takes.
template<kernel_mul_mv_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    // not part of the short signature
    (void) shmem;
    (void) tiitg;
    (void) sgitg;

    return disp_fn(args, src0, src1, dst, tgpig, tiisg);
}

// Adapter for the extended dispatch signature (scratch memory + simdgroup index):
// forwards everything except the thread index within the threadgroup.
template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    // not part of the extended signature
    (void) tiitg;

    return disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
9427
// Function type shared by the mmv_fn<> adapters; kernel_mul_mv_id is templated on a
// function of this type.
typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;

// Indirect (id-selected) matrix-vector multiplication.
//
// tgpig.z encodes a pair (iid1, idx) as tgpig.z = iid1*nei0 + idx. Element idx of int32 row
// iid1 of `ids` (row stride nbi1 bytes) gives i02, the index of the src0 matrix to use
// (presumably an expert id in mixture-of-experts use -- confirm against the host code).
// That matrix is multiplied with src1 row (i11 = idx % ne11, i12 = iid1), and the result is
// written to dst row idx of slice iid1. The product itself is delegated to disp_fn, using
// kargs that describe one src0 matrix times one src1 row, and with tgpig.z reset to 0.
template<mul_mv_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
        constant ggml_metal_kargs_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device       char * dst,
        device const char * ids,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // decode (iid1, idx) from the z threadgroup coordinate
    const int iid1 = tgpig.z/args.nei0;
    const int idx  = tgpig.z%args.nei0;

    // the dispatcher below must only see a single z-slice
    tgpig.z = 0;

    // index of the src0 matrix selected for this (iid1, idx) pair
    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];

    // NOTE(review): the modulo presumably lets a src1 with fewer rows than nei0 (e.g. ne11 == 1) be reused -- confirm
    const int64_t i11 = idx % args.ne11;
    const int64_t i12 = iid1;

    const int64_t i1 = idx;
    const int64_t i2 = i12;

    device const char * src0_cur = src0s + i02*args.nb02;
    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;

    // NOTE(review): dst is addressed as a contiguous float tensor (ne0 floats per row, ne1 rows per slice); no dst strides are used
    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);

    // kargs for a single src0 matrix (ne00, ne01) times a single src1 row, producing one dst row;
    // the initializer is positional, so its order must follow ggml_metal_kargs_mul_mv's field order
    ggml_metal_kargs_mul_mv args0 = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1, // args.ne02,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02, // args.ne02 == 1
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1, // args.ne11,
        /*.ne12 =*/ 1, // args.ne12,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12, // ne12 == 1
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1, // args.ne1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    disp_fn(
        args0,
        /* src0 */ src0_cur,
        /* src1 */ src1_cur,
        /* dst  */ dst_cur,
        shmem,
        tgpig,
        tiitg,
        tiisg,
        sgitg);
}
9493
// Kernel types used to declare the explicit instantiations below. kernel_mul_mv_id's parameter
// list does not depend on its disp_fn template argument, so both typedefs name the same
// signature; they differ only in which specialization is used to spell it.
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;

typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;

// float / half / bfloat src0 with scalar element access
template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half,  float>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
#endif
// _4 variants: the dispatcher is instantiated with 4-wide vector types (float4/half4/bfloat4)
template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half,  half4,  float, float4>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
#endif

// quantized src0 types; each dispatcher takes its N_R0_* constant as a template argument
// (presumably the number of src0 rows handled per simdgroup -- confirm at the definitions)
template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;

template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;

template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;

template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
9532
// 2-D max pooling, one thread per output element.
//
// gid enumerates outputs as (plane, oy, ox) with ox fastest. src0 and dst are addressed as
// contiguous IH*IW and OH*OW float planes. The window for (oy, ox) starts at
// (oy*s1 - p1, ox*s0 - p0), spans k1 x k0 cells, and is clipped to the input. A window with
// no in-bounds cell yields -INFINITY.
kernel void kernel_pool_2d_max_f32(
        constant    ggml_metal_kargs_pool_2d & args,
        device  const float * src0,
        device        float * dst,
        uint        gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const int in_plane  = args.IH * args.IW;
    const int out_plane = args.OH * args.OW;

    // split the flat output index into plane / row / column
    const int flat  = gid;
    const int plane = flat / out_plane;
    const int pos   = flat % out_plane;
    const int oy    = pos / args.OW;
    const int ox    = pos % args.OW;

    // window origin (may be negative because of padding) and its clipped extent
    const int y0   = oy * args.s1 - args.p1;
    const int x0   = ox * args.s0 - args.p0;
    const int y_lo = MAX(0, y0);
    const int y_hi = MIN(args.IH, y0 + args.k1);
    const int x_lo = MAX(0, x0);
    const int x_hi = MIN(args.IW, x0 + args.k0);

    device const float * plane_in = src0 + plane * in_plane;

    float best = -INFINITY;
    for (int y = y_lo; y < y_hi; ++y) {
        device const float * row = plane_in + y * args.IW;
        for (int x = x_lo; x < x_hi; ++x) {
            const float v = row[x];
            best = best > v ? best : v;
        }
    }

    dst[plane * out_plane + oy * args.OW + ox] = best;
}
9570
// 2-D average pooling, one thread per output element.
//
// gid enumerates outputs as (plane nc, row cur_oh, col cur_ow) with cur_ow fastest. src0 and
// dst are addressed as contiguous IH*IW and OH*OW float planes. The window for
// (cur_oh, cur_ow) starts at (cur_oh*s1 - p1, cur_ow*s0 - p0), spans k1 x k0 cells, and is
// clipped to the input. Each in-bounds value is scaled by 1/(k0*k1) and summed, so cells in the
// padding count as zeros in the average.
kernel void kernel_pool_2d_avg_f32(
        constant    ggml_metal_kargs_pool_2d & args,
        device  const float * src0,
        device        float * dst,
        uint        gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const int idx = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;
    const int nc = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    // window origin (may be negative because of padding), clipped to the input extent
    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh = MAX(0,  start_h);
    const int eh = MIN(args.IH, start_h + args.k1);
    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw = MAX(0,  start_w);
    const int ew = MIN(args.IW, start_w + args.k0);

    // Divisor is the full kernel area, so padded cells count as zeros. The alternative
    // 1.0f / ((eh - bh) * (ew - bw)) would exclude padding from the divisor instead.
    // Float literal on purpose: Metal has no double type.
    const float scale = 1.0f / (args.k0 * args.k1);

    float res = 0;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            float cur = i_ptr[i * args.IW + j];
            res += cur * scale;
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}
9611
9612
// 1-D max pooling along the innermost axis, one thread per output element.
//
// gid = line*OW + col. Output col covers input columns [col*s0 - p0, col*s0 - p0 + k0)
// clipped to [0, IW); columns outside the input are ignored. A window with no in-bounds
// column yields -INFINITY.
kernel void kernel_pool_1d_max_f32(
        constant        ggml_metal_kargs_pool_1d & args,
        device  const   float * src,
        device          float * dst,
        uint            gid [[thread_position_in_grid]]
) {
    if (gid >= args.np) {
        return;
    }

    const int flat = (int) gid;
    const int col  = flat % args.OW;
    const int line = flat / args.OW;

    // clip the window to the valid input range once instead of testing every tap
    const int first = col * args.s0 - args.p0;
    const int lo    = MAX(first, 0);
    const int hi    = MIN(first + args.k0, args.IW);

    device const float * in_row = src + line * args.IW;

    float peak = -INFINITY;
    for (int j = lo; j < hi; ++j) {
        peak = max(peak, in_row[j]);
    }

    dst[line * args.OW + col] = peak;
}
9645
// 1-D average pooling along the innermost axis, one thread per output element.
//
// gid = line*OW + col. Output col averages input columns [col*s0 - p0, col*s0 - p0 + k0)
// clipped to [0, IW). Only in-bounds columns enter the divisor (padding is excluded), and a
// window with no in-bounds column yields 0.
kernel void kernel_pool_1d_avg_f32(
        constant        ggml_metal_kargs_pool_1d & args,
        device  const   float * src,
        device          float * dst,
        uint            gid [[thread_position_in_grid]]
) {
    if (gid >= args.np) {
        return;
    }

    const int flat = (int) gid;
    const int col  = flat % args.OW;
    const int line = flat / args.OW;

    // clip the window to the valid input range once instead of testing every tap
    const int first = col * args.s0 - args.p0;
    const int lo    = MAX(first, 0);
    const int hi    = MIN(first + args.k0, args.IW);

    device const float * in_row = src + line * args.IW;

    float total = 0.0f;
    for (int j = lo; j < hi; ++j) {
        total += in_row[j];
    }

    // number of in-bounds taps; non-positive when the window lies entirely outside the input
    const int taps = hi - lo;
    dst[line * args.OW + col] = taps > 0 ? total / (float) taps : 0.0f;
}
9679
// One AdamW update step, one thread per parameter element.
//
// pars holds 7 floats: [0] step size, [1] beta1, [2] beta2, [3] eps, [4] weight decay,
// [5] and [6] factors applied to the first/second moments before use (presumably the
// host-precomputed bias corrections -- confirm). g_m and g_v are updated in place with the
// new moment estimates, and x receives the decayed-and-stepped value.
kernel void kernel_opt_step_adamw_f32(
        constant    ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        device const float * pars,
        uint        gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const float lr     = pars[0];
    const float b1     = pars[1];
    const float b2     = pars[2];
    const float epsil  = pars[3];
    const float decay  = pars[4];
    const float b1_cor = pars[5];
    const float b2_cor = pars[6];

    // exponential moving averages of the gradient and the squared gradient
    const float grad  = g[gid];
    const float m_new = g_m[gid] * b1 +        grad * (1.0f - b1);
    const float v_new = g_v[gid] * b2 + grad * grad * (1.0f - b2);

    g_m[gid] = m_new;
    g_v[gid] = v_new;

    const float m_hat = m_new * b1_cor;
    const float denom = sqrt(v_new * b2_cor) + epsil;

    x[gid] = x[gid] * (1.0f - lr * decay) - lr * m_hat / denom;
}
9713
// One SGD update step with decoupled weight decay, one thread per parameter element:
//   x <- x * (1 - pars[0]*pars[1]) - pars[0] * g
// pars[0] scales the gradient step and, multiplied with pars[1], the decay of x.
kernel void kernel_opt_step_sgd_f32(
        constant    ggml_metal_kargs_opt_step_sgd & args,
        device       float * x,
        device const float * g,
        device const float * pars,
        uint        gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const float step  = pars[0];
    const float decay = pars[1];

    x[gid] = x[gid] * (1.0f - step * decay) - step * g[gid];
}
9727
// Writes args.val into dst, one element per thread.
// NOTE(review): there is no bounds check and the kernel reads only args.val, so the host must
// dispatch exactly as many threads as dst has elements.
template<typename T>
kernel void kernel_memset(
        constant ggml_metal_kargs_memset & args,
        device T * dst,
        uint tpig[[thread_position_in_grid]]) {
    device T * slot = dst + tpig;
    *slot = args.val;
}

using kernel_memset_t = decltype(kernel_memset<int64_t>);

template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
9739
// Function constant: number of simdgroups per threadgroup (NSG) used by kernel_count_equal.
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];

// Counts positions where src0 and src1 hold equal values of type T and adds the count to *dst.
//
// One threadgroup handles one row (i1, i2, i3) = tgpig.xyz. Its threads stride along the row's
// ne00 elements (byte strides nb00 / nb10). Per-thread counts are reduced per simdgroup, the
// NSG partials are staged in shmem_i32, and simdgroup 0 reduces those and performs a single
// relaxed atomic add into *dst.
// Requires: shmem_i32 holds at least NSG ints. The final step reads one partial per lane of
// simdgroup 0, which assumes NSG <= 32 and a 1-D threadgroup.
// NOTE(review): *dst is accumulated into, so it presumably must be zeroed before dispatch -- confirm on the host side.
template<typename T>
kernel void kernel_count_equal(
        constant ggml_metal_kargs_count_equal & args,
        device   const char * src0,
        device   const char * src1,
        device   atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // depends only on tgpig, so the whole threadgroup returns together and the barrier below
    // is never reached by only part of it
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    // reduce within each simdgroup, then stage one partial per simdgroup in threadgroup memory
    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        // Reduce the NSG partials in integer arithmetic. A float accumulator is only exact up
        // to 2^24 and would miscount rows with more matches than that.
        int32_t v = 0;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        const int32_t total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, total, memory_order_relaxed);
        }
    }
}

typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;

template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;