1#pragma once
   2
   3#include "common.cuh"
   4#include "convert.cuh"
   5#include "vecdotq.cuh"
   6
   7#include <cstdint>
   8
// Width (in KV positions) of the tiles along the K/V sequence dimension; e.g. flash_attn_mask_to_KV_max scans the
// mask backwards in steps of this many positions.
#define FATTN_KQ_STRIDE       256
#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

// log(2) = 0.6931. Adding a multiple of it to the KQ maximum used for the softmax effectively shifts the numerical
//     range representable by the VKQ accumulators up by the corresponding power of 2.
// This reduces issues with numerical overflow, but it also causes larger values to be flushed to zero.
// However, since the output of FlashAttention is usually an input to a matrix multiplication, this should be negligible.
// Still, the value range should be shifted as much as necessary but as little as possible.
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
  20
  21typedef void (* fattn_kernel_t)(
  22        const char * __restrict__ Q,
  23        const char * __restrict__ K,
  24        const char * __restrict__ V,
  25        const char * __restrict__ mask,
  26        const char * __restrict__ sinks,
  27        const int  * __restrict__ KV_max,
  28        float      * __restrict__ dst,
  29        float2     * __restrict__ dst_meta,
  30        const float scale,
  31        const float max_bias,
  32        const float m0,
  33        const float m1,
  34        const uint32_t n_head_log2,
  35        const float logit_softcap,
  36        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
  37                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
  38        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
  39                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
  40                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
  41                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
  42                            const int32_t nb31, const int32_t nb32, const int64_t nb33);
  43
  44typedef float (*vec_dot_KQ_t)(
  45    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
  46
// Partial K·Q dot product for one f16 K row (K_c points to its half2 data).
// The `nthreads` threads sharing the row (lane = threadIdx.x % nthreads) each handle cpy_ne consecutive half2
// values per iteration, loaded with one vectorized copy. Only this thread's share of the dot product is returned;
// NOTE(review): the reduction across the nthreads lanes presumably happens in the caller.
// Q_v is indexed without threadIdx, so it must already point to this thread's slice of Q: half2 values when
// V_DOT2_F32_F16_AVAILABLE is defined, float2 values otherwise. Q_q8 / Q_ds_v are unused (quantized K only).
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

    const half2 * K_h2 = (const half2 *) K_c;
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); // bytes per vectorized copy (presumably the widest supported)
    constexpr int cpy_ne = cpy_nb / 4;                   // half2 values (4 bytes each) per copy

    float sum = 0.0f;

    // A row holds D/2 half2 values; the cooperating threads consume nthreads*cpy_ne of them per iteration.
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
        __align__(16) half2 tmp[cpy_ne];
        ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef V_DOT2_F32_F16_AVAILABLE
            ggml_cuda_mad(sum,                tmp[k_KQ_1] , ((const half2  *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
            ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
        }
    }

    return sum;
}
  76
// Partial K·Q dot product for one q4_0 K row against Q pre-quantized to 8 bit.
// Q_q8 holds 4 packed int8 values per int and Q_ds_v one float2 (Q scale, Q sum) per int. Both are indexed
// without threadIdx, so they must already point to this thread's slice.
// k_KQ enumerates groups of 4 consecutive K values; this thread takes every nthreads-th group.
// q4_0 dequantizes as d*(q - 8) (cf. dequantize_V_q4_0), so the dp4a result is corrected with the Q sum.
// Returns only this thread's share; NOTE(review): the cross-thread reduction presumably happens in the caller.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib    = k_KQ /  QI8_1;    // block index (QI8_1 groups of 4 values per block)
        const int iqs4  = k_KQ %  QI4_0;    // int index into the block's packed nibbles (qs)
        const int shift = k_KQ & (QI8_1/2); // bit shift selecting the low (0) or high (4) nibbles

        int v;
        // Alignment hint 2: presumably qs in block_q4_0 is only 2-byte aligned (it follows a half scale).
        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
        // Offset correction for the -8 of q4_0; (8/QI8_1) spreads it over the QI8_1 ints of a block.
        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
    }

    return sum;
}
 107
// Partial K·Q dot product for one q4_1 K row against Q pre-quantized to 8 bit.
// Q_q8 holds 4 packed int8 values per int and Q_ds_v one float2 (Q scale, Q sum) per int. Both are indexed
// without threadIdx, so they must already point to this thread's slice.
// k_KQ enumerates groups of 4 consecutive K values; this thread takes every nthreads-th group.
// q4_1 dequantizes as d*q + m (cf. dequantize_V_q4_1), hence the extra m*Q-sum term.
// Returns only this thread's share; NOTE(review): the cross-thread reduction presumably happens in the caller.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib    = k_KQ /  QI8_1;    // block index (QI8_1 groups of 4 values per block)
        const int iqs4  = k_KQ %  QI4_1;    // int index into the block's packed nibbles (qs)
        const int shift = k_KQ & (QI8_1/2); // bit shift selecting the low (0) or high (4) nibbles

        int v;
        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 K_dm = __half22float2(K_q4_1[ib].dm); // (d, m) of the K block
        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        // d_K*d_Q*sumi plus the min term, whose Q sum is spread over the QI8_1 ints of a block.
        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
    }

    return sum;
}
 140
// Partial K·Q dot product for one q5_0 K row against Q pre-quantized to 8 bit.
// Q_q8 holds 4 packed int8 values per int and Q_ds_v one float2 (Q scale, Q sum) per int. Both are indexed
// without threadIdx, so they must already point to this thread's slice.
// k_KQ enumerates groups of 4 consecutive K values; this thread takes every nthreads-th group.
// q5_0 dequantizes as d*(q - 16), where q's low 4 bits come from qs and its 5th bit from qh
// (cf. dequantize_V_q5_0), so the dp4a result is corrected with the Q sum.
// Returns only this thread's share; NOTE(review): the cross-thread reduction presumably happens in the caller.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib    = k_KQ /  QI8_1;    // block index (QI8_1 groups of 4 values per block)
        const int iqs4  = k_KQ %  QI5_0;    // int index into the block's packed nibbles (qs)
        const int iqs8  = k_KQ %  QI8_1;    // group index within the block
        const int shift = k_KQ & (QI8_1/2); // bit shift selecting the low (0) or high (4) nibbles

        int v;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;

        {
            // Bring this group's 4 high bits down to bits 0..3 of vh, then OR each into bit 4 of its byte of v.
            int vh;
            ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
            vh >>= iqs8 * QI5_0;

            v |= (vh <<  4) & 0x00000010; // 0 ->  4
            v |= (vh << 11) & 0x00001000; // 1 -> 12
            v |= (vh << 18) & 0x00100000; // 2 -> 20
            v |= (vh << 25) & 0x10000000; // 3 -> 28
        }

        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        // Offset correction for the -16 of q5_0; (16/QI8_1) spreads it over the QI8_1 ints of a block.
        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
    }

    return sum;
}
 185
// Partial K·Q dot product for one q5_1 K row against Q pre-quantized to 8 bit.
// Q_q8 holds 4 packed int8 values per int and Q_ds_v one float2 (Q scale, Q sum) per int. Both are indexed
// without threadIdx, so they must already point to this thread's slice.
// k_KQ enumerates groups of 4 consecutive K values; this thread takes every nthreads-th group.
// q5_1 dequantizes as d*q + m, where q's low 4 bits come from qs and its 5th bit from qh
// (cf. dequantize_V_q5_1), hence the extra m*Q-sum term.
// Returns only this thread's share; NOTE(review): the cross-thread reduction presumably happens in the caller.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib    = k_KQ /  QI8_1;    // block index (QI8_1 groups of 4 values per block)
        const int iqs4  = k_KQ %  QI5_1;    // int index into the block's packed nibbles (qs)
        const int iqs8  = k_KQ %  QI8_1;    // group index within the block
        const int shift = k_KQ & (QI8_1/2); // bit shift selecting the low (0) or high (4) nibbles

        int v;
        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
        v = (v >> shift) & 0x0F0F0F0F;

        {
            // Bring this group's 4 high bits down to bits 0..3 of vh, then OR each into bit 4 of its byte of v.
            // NOTE(review): this q5_1 path uses the q5_0 constant QI5_0 as "values per group"; correct only as
            // long as QI5_0 == QI5_1 == 4 — confirm.
            int vh;
            ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
            vh >>= iqs8 * QI5_0;

            v |= (vh <<  4) & 0x00000010; // 0 ->  4
            v |= (vh << 11) & 0x00001000; // 1 -> 12
            v |= (vh << 18) & 0x00100000; // 2 -> 20
            v |= (vh << 25) & 0x10000000; // 3 -> 28
        }

        const int u = Q_q8[k_KQ_0/nthreads];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

        const float2 K_dm = __half22float2(K_q5_1[ib].dm); // (d, m) of the K block
        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];

        // d_K*d_Q*sumi plus the min term, whose Q sum is spread over the QI8_1 ints of a block.
        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
    }

    return sum;
}
 231
// Partial K·Q dot product for one q8_0 K row against Q pre-quantized to 8 bit.
// Each iteration dots 4 int8 K values with 4 int8 Q values via vec_dot_q8_0_q8_1_impl (defined elsewhere,
// presumably vecdotq.cuh). Only the Q scale (Q_ds.x) is needed, because q8_0 has no offset.
// Q_q8 / Q_ds_v are indexed without threadIdx, so they must already point to this thread's slice.
// Returns only this thread's share; NOTE(review): the cross-thread reduction presumably happens in the caller.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
    GGML_UNUSED(Q_v);

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

        const int ib  = k_KQ / QI8_0; // block index
        const int iqs = k_KQ % QI8_0; // int index within the block's qs

        int v;
        ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);

        const float2 * Q_ds = (const float2 *) Q_ds_v;
        const float Q_d = Q_ds[k_KQ_0/nthreads].x;

        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
    }

    return sum;
}
 259
 260template <typename Tds, int ni>
 261static __device__ __forceinline__ void quantize_q8_1_to_shared(
 262    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
 263
 264    float vals[sizeof(int)] = {0.0f};
 265#pragma unroll
 266    for (int l = 0; l < int(sizeof(int)); ++l) {
 267        vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
 268    }
 269
 270    float amax = fabsf(vals[0]);
 271    float sum  = vals[0];
 272#pragma unroll
 273    for (int l = 1; l < int(sizeof(int)); ++l) {
 274        amax = fmaxf(amax, fabsf(vals[l]));
 275        sum += vals[l];
 276    }
 277#pragma unroll
 278    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
 279        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
 280        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
 281    }
 282
 283    const float d = amax / 127;
 284    int q32 = 0;
 285    int8_t * q8 = (int8_t *) &q32;
 286
 287    if (d != 0.0f) {
 288#pragma unroll
 289        for (int l = 0; l < int(sizeof(int)); ++l) {
 290            q8[l] = roundf(vals[l] / d);
 291        }
 292    }
 293
 294    yq32[threadIdx.x] = q32;
 295    if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
 296        if (std::is_same<Tds, half2>::value) {
 297            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
 298        } else {
 299            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
 300        }
 301    }
 302}
 303
 304typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
 305
 306template <typename T, int ne>
 307static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
 308    if constexpr (std::is_same_v<T, half>) {
 309        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
 310    } else if constexpr (std::is_same_v<T, float>) {
 311        static_assert(ne % 2 == 0, "bad ne");
 312        __align__(16) half2 tmp[ne/2];
 313        ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
 314        float2 * dst_f2 = (float2 *) dst;
 315#pragma unroll
 316        for (int l = 0; l < ne/2; ++l) {
 317            dst_f2[l] = __half22float2(tmp[l]);
 318        }
 319    } else {
 320        static_assert(std::is_same_v<T, void>, "unsupported type");
 321    }
 322}
 323
// Dequantizes `ne` (2 or 4) consecutive q4_0 values of V, starting at element i0, into dst.
// The output type is half (only when FP16_AVAILABLE) or float. q4_0 value = d * (q - 8), with q a 4-bit nibble.
// NOTE(review): the ne values must lie in the same block and the same nibble half (ib/shift are computed once),
// i.e. i0 is presumably a multiple of ne — confirm with callers.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int64_t ib    =  i0          /  QK4_0;      // block index
    const int     iqs   =  i0          % (QK4_0/2);   // byte index into qs
    const int     shift = (i0 % QK4_0) / (QK4_0/2);   // 0: low nibbles (first half of block), 1: high nibbles

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    // For ne == 2 only the low 2 bytes of q are loaded; the upper bytes are never read via q8 below.
    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;
    q = __vsubss4(q, 0x08080808); // per-byte q - 8

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 d = __half2half2(x[ib].d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float d = x[ib].d;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = d * q8[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
 362
// Dequantizes `ne` (2 or 4) consecutive q4_1 values of V, starting at element i0, into dst.
// The output type is half (only when FP16_AVAILABLE) or float. q4_1 value = d * q + m, with q a 4-bit nibble and
// (d, m) stored together as a half2.
// NOTE(review): the ne values must lie in the same block and the same nibble half (ib/shift are computed once),
// i.e. i0 is presumably a multiple of ne — confirm with callers.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const int64_t ib    =  i0          /  QK4_1;      // block index
    const int     iqs   =  i0          % (QK4_1/2);   // byte index into qs
    const int     shift = (i0 % QK4_1) / (QK4_1/2);   // 0: low nibbles (first half of block), 1: high nibbles

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    // For ne == 2 only the low 2 bytes of q are loaded; the upper bytes are never read via q8 below.
    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 dm = x[ib].dm;
        const half2 d  = __half2half2( __low2half(dm));
        const half2 m  = __half2half2(__high2half(dm));

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float2 dm = __half22float2(x[ib].dm);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
 402
 403template <typename T, int ne>
 404static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
 405    const block_q5_0 * x = (const block_q5_0 *) vx;
 406
 407    const int64_t ib    =  i0          /  QK5_0;
 408    const int     idq   =  i0          %  QK5_0;
 409    const int     iqs   =  i0          % (QK5_0/2);
 410    const int     shift = (i0 % QK5_0) / (QK5_0/2);
 411
 412    int q;
 413    static_assert(ne == 2 || ne == 4, "bad ne");
 414    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
 415    q >>= 4*shift;
 416    q &= 0x0F0F0F0F;
 417
 418    {
 419        int qh;
 420        ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
 421#pragma unroll
 422        for (int l = 0; l < ne; ++l) {
 423            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
 424        }
 425    }
 426
 427    q = __vsubss4(q, 0x10101010);
 428
 429    const int8_t * q8 = (const int8_t *) &q;
 430
 431#ifdef FP16_AVAILABLE
 432    if constexpr (std::is_same_v<T, half>) {
 433        const half2 d = __half2half2(x[ib].d);
 434
 435#pragma unroll
 436        for (int l0 = 0; l0 < ne; l0 += 2) {
 437            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
 438        }
 439    } else
 440#endif // FP16_AVAILABLE
 441    if constexpr (std::is_same_v<T, float>) {
 442        const float d = x[ib].d;
 443
 444#pragma unroll
 445        for (int l = 0; l < ne; ++l) {
 446            ((float *) dst)[l] = d * q8[l];
 447        }
 448    } else {
 449        static_assert(std::is_same_v<T, void>, "bad type");
 450    }
 451}
 452
 453template <typename T, int ne>
 454static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
 455    const block_q5_1 * x = (const block_q5_1 *) vx;
 456
 457    const int64_t ib    =  i0          /  QK5_1;
 458    const int     idq   =  i0          %  QK5_1;
 459    const int     iqs   =  i0          % (QK5_1/2);
 460    const int     shift = (i0 % QK5_1) / (QK5_1/2);
 461
 462    int q;
 463    static_assert(ne == 2 || ne == 4, "bad ne");
 464    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
 465    q >>= 4*shift;
 466    q &= 0x0F0F0F0F;
 467
 468    {
 469        int qh;
 470        ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
 471#pragma unroll
 472        for (int l = 0; l < ne; ++l) {
 473            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
 474        }
 475    }
 476
 477    const int8_t * q8 = (const int8_t *) &q;
 478
 479#ifdef FP16_AVAILABLE
 480    if constexpr (std::is_same_v<T, half>) {
 481        const half2 dm = x[ib].dm;
 482        const half2 d  = __half2half2( __low2half(dm));
 483        const half2 m  = __half2half2(__high2half(dm));
 484
 485#pragma unroll
 486        for (int l0 = 0; l0 < ne; l0 += 2) {
 487            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
 488        }
 489    } else
 490#endif // FP16_AVAILABLE
 491    if constexpr (std::is_same_v<T, float>) {
 492        const float2 dm = __half22float2(x[ib].dm);
 493
 494#pragma unroll
 495        for (int l = 0; l < ne; ++l) {
 496            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
 497        }
 498    } else {
 499        static_assert(std::is_same_v<T, void>, "bad type");
 500    }
 501}
 502
 503template <typename T, int ne>
 504static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
 505    const block_q8_0 * x = (const block_q8_0 *) vx;
 506
 507    const int64_t ib  = i0 / QK8_0;
 508    const int     iqs = i0 % QK8_0;
 509
 510    static_assert(ne % 2 == 0, "bad ne");
 511    int8_t qs[ne];
 512    ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
 513
 514#ifdef FP16_AVAILABLE
 515    if constexpr (std::is_same<T, half>::value) {
 516        const half2 d = __half2half2(x[ib].d);
 517
 518#pragma unroll
 519        for (int l0 = 0; l0 < ne; l0 += 2) {
 520            ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
 521        }
 522    } else
 523#endif // FP16_AVAILABLE
 524    if constexpr (std::is_same<T, float>::value) {
 525        const float d = x[ib].d;
 526
 527#pragma unroll
 528        for (int l = 0; l < ne; ++l) {
 529            ((float *) dst)[l] = d * qs[l];
 530        }
 531    } else {
 532        static_assert(std::is_same_v<T, void>, "unsupported type");
 533    }
 534}
 535
// Compile-time selection of the K·Q partial dot-product implementation for K stored as type_K
// (F16, Q4_0, Q4_1, Q5_0, Q5_1 or Q8_0). Any other type fails to compile via the static_assert;
// the trailing `return nullptr` is never reached.
template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
    if constexpr (type_K == GGML_TYPE_F16) {
        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
    } else {
        static_assert(type_K == -1, "bad type");
        return nullptr;
    }
}
 555
// Compile-time selection of the V dequantization routine for V stored as type_V
// (F16, Q4_0, Q4_1, Q5_0, Q5_1 or Q8_0). It produces `ne` values of type T per call.
// Any other type fails to compile via the static_assert; the trailing `return nullptr` is never reached.
template <ggml_type type_V, typename T, int ne>
constexpr __device__ dequantize_V_t get_dequantize_V() {
    if constexpr (type_V == GGML_TYPE_F16) {
        return dequantize_V_f16<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
        return dequantize_V_q4_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
        return dequantize_V_q4_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
        return dequantize_V_q5_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
        return dequantize_V_q5_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
        return dequantize_V_q8_0<T, ne>;
    } else {
        static_assert(type_V == -1, "bad type");
        return nullptr;
    }
}
 575
// For each tile of ncols1 consecutive Q columns (blockIdx.x = jt) and each sequence (blockIdx.y), finds how far
// along the KV dimension the mask contains values that are not +/-inf.
// Starting from the last FATTN_KQ_STRIDE-wide KV tile and walking backwards, it stops at the first tile (from the
// end) in which at least one of the ncols1 mask rows has a non-inf value. It writes the end of that tile to
// KV_max[sequence*ne31 + jt], or 0 if every tile is fully masked, so fully masked trailing KV tiles can be skipped.
// Launch: gridDim.x = number of column tiles (ne31), gridDim.y = number of sequences. Each thread reads one half2
// (two mask values) per row and tile, so blockDim.x is presumably FATTN_KQ_STRIDE/2 (cf. __launch_bounds__).
// ne30: number of KV tiles to scan (the first tile checked starts at (ne30 - 1)*FATTN_KQ_STRIDE) — TODO confirm
//       the caller passes a tile count.
// s31 / s33: strides of the mask between rows / between sequences, in half2 elements (added to a half2 pointer).
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
    const int ne31     = gridDim.x;
    const int tid      = threadIdx.x;
    const int sequence = blockIdx.y;
    const int jt       = blockIdx.x;

    mask += sequence*s33 + jt*ncols1*s31;

    // Per-warp results. Slots of warps that do not exist stay 1 so they cannot veto the block-wide AND below.
    __shared__ int buf_iw[WARP_SIZE];
    if (tid < WARP_SIZE) {
        buf_iw[tid] = 1;
    }
    __syncthreads();

    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
        int all_inf = 1;

#pragma unroll
        for (int j = 0; j < ncols1; ++j) {
            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
        }

        // Block-wide AND: reduce within each warp, publish per-warp results, then reduce those again in every warp.
        all_inf = warp_reduce_all(all_inf);
        if (tid % WARP_SIZE == 0) {
            buf_iw[tid / WARP_SIZE] = all_inf;
        }
        __syncthreads();
        all_inf = buf_iw[tid % WARP_SIZE];
        __syncthreads(); // Keep the next iteration from overwriting buf_iw before every warp has read it.
        all_inf = warp_reduce_all(all_inf);

        if (!all_inf) {
            break;
        }
    }

    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
    KV_max_sj += FATTN_KQ_STRIDE;

    if (threadIdx.x != 0) {
        return;
    }

    KV_max[sequence*ne31 + jt] = KV_max_sj;
}
 628
// Stream-k fixup pass.
// In the main kernel, the flattened work items (the iter_k KV iterations of every output tile) are split evenly
// across gridDim.x CUDA blocks, so one output tile can be shared by several consecutive blocks. Each such block left
// a partial (unnormalized) result plus (KQ max, KQ rowsum) meta data. This kernel merges them in the block that
// wrote the end of a tile whose beginning was written by an earlier block, then normalizes the result.
// Launch (inferred from the indexing): gridDim.x = number of blocks of the main kernel (bidx0),
// gridDim.y = ncols1 (Q column j within a tile), gridDim.z = ncols2 (Q head c within a GQA group),
// blockDim.x = D (one thread per element of the head dimension).
// dst_fixup layout:
//   float2 [bidx*ncols + jc]               - meta of the partial result block bidx left in dst,
//   float2 [(gridDim.x + bidx)*ncols + jc] - meta of the partial result block bidx stored in dst_fixup_data,
//   float  dst_fixup_data[bidx*ncols*D + jc*D + tid], starting after those 2*gridDim.x*ncols float2 entries.
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
        const int ne11, const int ne12, const int nbatch_fa) {
    constexpr int ncols = ncols1*ncols2;

    const int bidx0 = blockIdx.x;
    const int j     = blockIdx.y;
    const int c     = blockIdx.z;
    const int jc    = j*ncols2 + c;
    const int tid   = threadIdx.x;

    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

    // Number of KV iterations per tile, Q column tiles, and Q head groups per K/V head.
    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;

    // Range [kbc0, kbc0_stop) of flattened work items that main-kernel block bidx0 processed.
    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;

    // Nothing to fix up unless this block wrote the last part of a tile that it did not start itself.
    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

    // Skip padding rows/heads of tiles that extend past ne01 / gqa_ratio.
    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
        return;
    }

    // Equivalent to dst[((sequence*ne01 + jt*ncols1 + j)*ne02 + zt_Q + c)*D + tid],
    // i.e. dst is laid out as [sequence][Q column][head][D] with D fastest.
    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;

    // Load the partial result that needs a fixup:
    float dst_val = 0.0f;
    float max_val = 0.0f;
    float rowsum  = 0.0f;
    {
        dst_val = *dst;

        const float2 tmp = dst_fixup[bidx0*ncols + jc];
        max_val = tmp.x;
        rowsum  = tmp.y;
    }

    // Iterate over previous blocks and compute the combined results.
    // All CUDA blocks that get here must have a previous block that needs a fixup.
    int bidx = bidx0 - 1;
    int kbc_stop = kbc0;
    while(true) {
        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }

        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];

        // Scale the current and new value accumulators depending on the max. values.
        const float max_val_new = fmaxf(max_val, tmp.x);

        const float diff_val = max_val - max_val_new;
        const float diff_add = tmp.x   - max_val_new;

        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

        dst_val = scale_val*dst_val + scale_add*dst_add;
        rowsum  = scale_val*rowsum  + scale_add*tmp.y;

        max_val = max_val_new;

        // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
            break;
        }
        bidx--;
        kbc_stop = kbc;
    }

    // Write back final result:
    *dst = dst_val / rowsum;
}
 727
 728template<int D> // D == head size
 729__launch_bounds__(D, 1)
 730static __global__ void flash_attn_combine_results(
 731        const float  * __restrict__ VKQ_parts,
 732        const float2 * __restrict__ VKQ_meta,
 733        float * __restrict__ dst,
 734        const int parallel_blocks) {
 735    // Dimension 0: threadIdx.x
 736    // Dimension 1: blockIdx.x
 737    // Dimension 2: blockIdx.y
 738    // Dimension 3: blockIdx.z
 739    // Memory layout is permuted with [0, 2, 1, 3]
 740
 741    const int ne01 = gridDim.x;
 742    const int ne02 = gridDim.y;
 743
 744    const int col      = blockIdx.x;
 745    const int head     = blockIdx.y;
 746    const int sequence = blockIdx.z;
 747
 748    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
 749
 750    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
 751    VKQ_meta  += j_dst_unrolled * parallel_blocks;
 752    dst       += j_dst_unrolled *                 D;
 753
 754    const int tid = threadIdx.x;
 755    __builtin_assume(tid < D);
 756
 757    extern __shared__ float2 meta[];
 758    for (int i = tid; i < 2*parallel_blocks; i += D) {
 759        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
 760    }
 761
 762    __syncthreads();
 763
 764    float kqmax = meta[0].x;
 765    for (int l = 1; l < parallel_blocks; ++l) {
 766        kqmax = max(kqmax, meta[l].x);
 767    }
 768
 769    float VKQ_numerator   = 0.0f;
 770    float VKQ_denominator = 0.0f;
 771    for (int l = 0; l < parallel_blocks; ++l) {
 772        const float KQ_max_scale = expf(meta[l].x - kqmax);
 773
 774        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
 775        VKQ_denominator += KQ_max_scale * meta[l].y;
 776    }
 777
 778    dst[tid] = VKQ_numerator / VKQ_denominator;
 779}
 780
 781template <int DV, int ncols1, int ncols2>
 782void launch_fattn(
 783    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
 784    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 785) {
 786    constexpr int ncols = ncols1 * ncols2;
 787
 788    const ggml_tensor * Q = dst->src[0];
 789    const ggml_tensor * K = dst->src[1];
 790    const ggml_tensor * V = dst->src[2];
 791
 792    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
 793
 794    const ggml_tensor * mask  = dst->src[3];
 795    const ggml_tensor * sinks = dst->src[4];
 796
 797    ggml_tensor * KQV = dst;
 798
 799    GGML_ASSERT(Q->type == GGML_TYPE_F32);
 800    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 801
 802    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
 803    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
 804    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
 805
 806    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
 807
 808    ggml_cuda_pool & pool = ctx.pool();
 809    cudaStream_t main_stream = ctx.stream();
 810    const int id  = ggml_cuda_get_device();
 811    const int cc  = ggml_cuda_info().devices[id].cc;
 812    const int nsm = ggml_cuda_info().devices[id].nsm;
 813
 814    ggml_cuda_pool_alloc<half>   K_f16(pool);
 815    ggml_cuda_pool_alloc<half>   V_f16(pool);
 816    ggml_cuda_pool_alloc<int>    KV_max(pool);
 817    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
 818    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 819
 820    const char * K_data = (const char *) K->data;
 821    size_t nb11 = K->nb[1];
 822    size_t nb12 = K->nb[2];
 823    size_t nb13 = K->nb[3];
 824
 825    const char * V_data = (const char *) V->data;
 826    size_t nb21 = V->nb[1];
 827    size_t nb22 = V->nb[2];
 828    size_t nb23 = V->nb[3];
 829
 830    if (need_f16_K && K->type != GGML_TYPE_F16) {
 831        const size_t bs = ggml_blck_size(K->type);
 832        const size_t ts = ggml_type_size(K->type);
 833
 834        K_f16.alloc(ggml_nelements(K));
 835        if (ggml_is_contiguously_allocated(K)) {
 836            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
 837            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
 838
 839            nb11 = nb11*bs*sizeof(half)/ts;
 840            nb12 = nb12*bs*sizeof(half)/ts;
 841            nb13 = nb13*bs*sizeof(half)/ts;
 842        } else {
 843            GGML_ASSERT(K->nb[0] == ts);
 844            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
 845            const int64_t s01 = nb11 / ts;
 846            const int64_t s02 = nb12 / ts;
 847            const int64_t s03 = nb13 / ts;
 848            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
 849
 850            nb11 = K->ne[0] * sizeof(half);
 851            nb12 = K->ne[1] * nb11;
 852            nb13 = K->ne[2] * nb12;
 853        }
 854        K_data = (char *) K_f16.ptr;
 855    }
 856
 857    if (need_f16_V && V->type != GGML_TYPE_F16) {
 858        if (V_is_K_view) {
 859            V_data = K_data;
 860            nb21   = nb11;
 861            nb22   = nb12;
 862            nb23   = nb13;
 863        } else {
 864            const size_t bs = ggml_blck_size(V->type);
 865            const size_t ts = ggml_type_size(V->type);
 866
 867            V_f16.alloc(ggml_nelements(V));
 868            if (ggml_is_contiguously_allocated(V)) {
 869                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
 870                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
 871                V_data = (char *) V_f16.ptr;
 872
 873                nb21 = nb21*bs*sizeof(half)/ts;
 874                nb22 = nb22*bs*sizeof(half)/ts;
 875                nb23 = nb23*bs*sizeof(half)/ts;
 876            } else {
 877                GGML_ASSERT(V->nb[0] == ts);
 878                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
 879                const int64_t s01 = nb21 / ts;
 880                const int64_t s02 = nb22 / ts;
 881                const int64_t s03 = nb23 / ts;
 882                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
 883
 884                nb21 = V->ne[0] * sizeof(half);
 885                nb22 = V->ne[1] * nb21;
 886                nb23 = V->ne[2] * nb22;
 887            }
 888            V_data = (char *) V_f16.ptr;
 889        }
 890    }
 891
 892    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
 893    const int gqa_ratio    = Q->ne[2] / K->ne[2];
 894    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
 895    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 896
 897    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
 898    // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
 899    //     multiple sequences of possibly different lengths.
 900    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
 901        const int s31 = mask->nb[1] / sizeof(half2);
 902        const int s33 = mask->nb[3] / sizeof(half2);
 903
 904        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
 905        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
 906
 907        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
 908        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
 909
 910        KV_max.alloc(ne_KV_max);
 911        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
 912            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
 913        CUDA_CHECK(cudaGetLastError());
 914    }
 915
 916    const dim3 block_dim(warp_size, nwarps, 1);
 917    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
 918    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
 919    GGML_ASSERT(max_blocks_per_sm > 0);
 920    int parallel_blocks = max_blocks_per_sm;
 921
 922    dim3 blocks_num;
 923    if (stream_k) {
 924        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
 925        const int max_blocks = max_blocks_per_sm*nsm;
 926        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
 927        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 928
 929        const int nblocks_stream_k = max_blocks;
 930
 931        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 932
 933        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
 934        blocks_num.y = 1;
 935        blocks_num.z = 1;
 936
 937        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
 938            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
 939        }
 940    } else {
 941        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
 942
 943        // parallel_blocks must not be larger than what the tensor size allows:
 944        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
 945
 946        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
 947        // Test whether parallel_blocks can be set to a higher value for better efficiency.
 948        const int blocks_per_wave = nsm * max_blocks_per_sm;
 949        int nwaves_best = 0;
 950        int efficiency_percent_best = 0;
 951        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
 952            const int nblocks_total = ntiles_total * parallel_blocks_test;
 953            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
 954            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
 955
 956            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
 957            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
 958                break;
 959            }
 960
 961            if (efficiency_percent > efficiency_percent_best) {
 962                nwaves_best = nwaves;
 963                efficiency_percent_best = efficiency_percent;
 964                parallel_blocks = parallel_blocks_test;
 965            }
 966        }
 967
 968        blocks_num.x = ntiles_x;
 969        blocks_num.y = parallel_blocks;
 970        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
 971
 972        if (parallel_blocks > 1) {
 973            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
 974            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
 975        }
 976    }
 977
 978    float scale         = 1.0f;
 979    float max_bias      = 0.0f;
 980    float logit_softcap = 0.0f;
 981
 982    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
 983    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
 984    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 985
 986    if (logit_softcap != 0.0f) {
 987        scale /= logit_softcap;
 988    }
 989
 990    const uint32_t n_head      = Q->ne[2];
 991    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 992
 993    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 994    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 995
 996    // TODO other tensor dimensions after removal of WMMA kernel:
 997    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
 998
 999    GGML_ASSERT(block_dim.x % warp_size == 0);
1000    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
1001        (const char *) Q->data,
1002        K_data,
1003        V_data,
1004        mask ? ((const char *) mask->data) : nullptr,
1005        sinks ? ((const char *) sinks->data) : nullptr,
1006        KV_max.ptr,
1007        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
1008        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
1009        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
1010        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
1011        nb21, nb22, nb23,
1012        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
1013        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
1014    );
1015    CUDA_CHECK(cudaGetLastError());
1016
1017    if (stream_k) {
1018        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
1019            const dim3 block_dim_combine(DV, 1, 1);
1020            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
1021
1022            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
1023                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
1024                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
1025        }
1026    } else if (parallel_blocks > 1) {
1027        const dim3 block_dim_combine(DV, 1, 1);
1028        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
1029        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
1030
1031        flash_attn_combine_results<DV>
1032            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
1033            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
1034    }
1035    CUDA_CHECK(cudaGetLastError());
1036}