1// Old and deprecated WMMA FlashAttention implementation.
  2// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
  3// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
  4
  5#include "common.cuh"
  6#include "fattn-common.cuh"
  7#include "fattn-wmma-f16.cuh"
  8
  9#ifdef GGML_USE_WMMA_FATTN
 10#if !defined(GGML_USE_HIP)
 11#include <mma.h>
 12#if defined(GGML_USE_MUSA)
 13namespace wmma = mtmusa::wmma;
 14#else // GGML_USE_MUSA
 15namespace wmma = nvcuda::wmma;
 16#endif // GGML_USE_MUSA
 17#elif defined(GGML_USE_HIP)
 18#include <rocwmma/rocwmma.hpp>
 19namespace wmma = rocwmma;
 20#endif // !defined(GGML_USE_HIP)
 21#endif // GGML_USE_WMMA_FATTN
 22
 23// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 24template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
 25__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 26static __global__ void flash_attn_ext_f16(
 27        const char * __restrict__ Q,
 28        const char * __restrict__ K,
 29        const char * __restrict__ V,
 30        const char * __restrict__ mask,
 31        const char * __restrict__ sinks,
 32        const int  * __restrict__ KV_max,
 33        float      * __restrict__ dst,
 34        float2     * __restrict__ dst_meta,
 35        const float scale,
 36        const float max_bias,
 37        const float m0,
 38        const float m1,
 39        const uint32_t n_head_log2,
 40        const float logit_softcap,
 41        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
 42                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
 43        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
 44                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
 45                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
 46                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
 47                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 48#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
 49    // Skip unused kernel variants for faster compilation:
 50    if (use_logit_softcap && !(D == 128 || D == 256)) {
 51        NO_DEVICE_CODE;
 52        return;
 53    }
 54
 55    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 56
 57    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 58
 59    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 60
 61    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
 62    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
 63    constexpr int frag_m = ncols == 8 ? 32 : 16;
 64    constexpr int frag_n = ncols == 8 ?  8 : 16;
 65    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
 66    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
 67    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
 68    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
 69    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
 70    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
 71
 72    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
 73    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
 74    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
 75
 76    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
 77    constexpr int D_padded = D + 8;
 78    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
 79    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 80
 81    const int sequence = blockIdx.z / ne02;
 82    const int head = blockIdx.z - sequence*ne02;
 83    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 84    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
 85    const half  * K_h    = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
 86    const half  * V_h    = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
 87    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
 88    const half2 * mask2  = (const half2 *)  maskh;
 89    const float * sinksf = (const float *) sinks;
 90
 91    const int stride_Q  = nb01 / sizeof(float);
 92    const int stride_KV = nb11 / sizeof(half);
 93
 94    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 95    const half  slopeh = __float2half(slopef);
 96    const half2 slope2 = make_half2(slopef, slopef);
 97
 98    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
 99
100    frag_b Q_b[D/16][ncols/frag_n];
101
102    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
103    constexpr int mem_KQ = ncols*kqs_padded*kqar;
104    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
105    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
106    float * KQ_f = (float *) KQ;
107    half2 * KQ2 = (half2 *) KQ;
108
109    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
110    float       KQ_max_f[ncols/nwarps];
111    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
112
113#pragma unroll
114    for (int j = 0; j < ncols/nwarps; ++j) {
115        KQ_max_f[j] = -FLT_MAX/2.0f;
116    }
117
118    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
119    half2       KQ_max_h2[ncols/nwarps];
120    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
121
122#pragma unroll
123    for (int j = 0; j < ncols/nwarps; ++j) {
124        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
125    }
126
127    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
128    half2 * VKQ2 = (half2 *) VKQ;
129#pragma unroll
130    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
131        const int j = j0 + threadIdx.y;
132#pragma unroll
133        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
134            const int i = i0 + threadIdx.x;
135            if (i0 + warp_size > D/2 && i >= D/2) {
136                break;
137            }
138            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
139        }
140    }
141
142    // Convert Q to half and apply scale, temporarily store in KQ:
143#pragma unroll
144    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
145        const int j = j0 + threadIdx.y;
146#pragma unroll
147        for (int i0 = 0; i0 < D; i0 += warp_size) {
148            const int i = i0 + threadIdx.x;
149            if (i0 + warp_size > D && i >= D) {
150                break;
151            }
152            KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
153        }
154    }
155
156    __syncthreads();
157
158    // Load Q into tensor core fragments/registers since it will be used frequently:
159#pragma unroll
160    for (int i0 = 0; i0 < D; i0 += 16) {
161#pragma unroll
162        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
163            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
164        }
165    }
166
167    __syncthreads();
168
169    // Iterate over ne11 == previous tokens:
170    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
171    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
172        // Calculate tile of KQ:
173#pragma unroll
174        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
175            frag_c_KQ KQ_c[ncols/frag_n];
176#pragma unroll
177            for (int j = 0; j < ncols/frag_n; ++j) {
178                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
179            }
180#pragma unroll
181            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
182                frag_a_K K_a;
183                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
184#pragma unroll
185                for (int j = 0; j < ncols/frag_n; ++j) {
186                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
187                }
188            }
189#pragma unroll
190            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
191                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
192            }
193        }
194
195        __syncthreads();
196
197        // Calculate softmax for each KQ column using the current max. value.
198        // The divisor is stored in KQ_rowsum and will be applied at the end.
199#pragma unroll
200        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
201            const int j = j0 + threadIdx.y;
202
203            if (std::is_same<KQ_acc_t, float>::value) {
204                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
205#pragma unroll
206                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
207                    const int k = k0 + threadIdx.x;
208
209                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
210
211                    if (use_logit_softcap) {
212                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
213                    }
214                }
215
216                float KQ_max_new = KQ_max_f[j0/nwarps];
217#pragma unroll
218                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
219                    const int k = k0 + threadIdx.x;
220
221                    KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
222                        __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
223                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
224                }
225                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
226
227                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
228                KQ_max_scale_f[j0/nwarps] = expf(diff);
229                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
230                    KQ_max_scale_f[j0/nwarps] = 0.0f;
231                }
232                KQ_max_f[j0/nwarps] = KQ_max_new;
233
234                float KQ_rowsum_add = 0.0f;
235#pragma unroll
236                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
237                    const int k = k0 + threadIdx.x;
238
239                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
240                    KQ_f_tmp[k0/warp_size] = expf(diff);
241                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
242                        KQ_f_tmp[k0/warp_size] = 0.0f;
243                    }
244                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
245                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
246                }
247                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
248
249                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
250                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
251            } else {
252                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
253#pragma unroll
254                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
255                    const int k = k0 + threadIdx.x;
256
257                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
258
259                    if (use_logit_softcap) {
260                        // There is no dedicated tangens hyperbolicus function for half2.
261                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
262                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
263                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
264
265                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
266                    }
267                }
268
269                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
270#pragma unroll
271                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
272                    const int k = k0 + threadIdx.x;
273
274                    KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
275                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
276                }
277                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
278                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
279                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
280                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
281                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
282                KQ_max_h2[j0/nwarps] = KQ_max_new;
283
284                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
285#pragma unroll
286                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
287                    const int k = k0 + threadIdx.x;
288
289                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
290                    KQ2_tmp[k0/warp_size] = h2exp(diff);
291                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
292                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
293                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
294                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
295                }
296                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
297
298                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
299                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
300            }
301        }
302
303        __syncthreads();
304
305        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
306#pragma unroll
307        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
308#pragma unroll
309            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
310                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
311                wmma::load_matrix_sync(
312                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
313                    KQ + j0*(kqar*kqs_padded) + k,
314                    kqar*kqs_padded);
315            }
316        }
317
318        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
319#pragma unroll
320        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
321#pragma unroll
322            for (int j = 0; j < ncols/frag_n; ++j) {
323                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
324            }
325
326#pragma unroll
327            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
328                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
329
330                frag_a_V v_a;
331                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
332#pragma unroll
333                for (int j = 0; j < ncols/frag_n; ++j) {
334                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
335                }
336            }
337        }
338
339        __syncthreads();
340
341        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
342#pragma unroll
343        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
344#pragma unroll
345            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
346                wmma::store_matrix_sync(
347                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
348                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349                    D_padded, wmma::mem_col_major);
350            }
351        }
352
353        __syncthreads();
354
355#pragma unroll
356        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
357            const int j = j0 + threadIdx.y;
358
359            half2 VKQ_scale;
360            if (std::is_same<KQ_acc_t, float>::value) {
361                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
362            } else {
363                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
364            }
365
366#pragma unroll
367            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
368                const int i = i0 + threadIdx.x;
369                if (i0 + warp_size > D/2 && i >= D/2) {
370                    break;
371                }
372
373                half2 VKQ_add = make_half2(0.0f, 0.0f);
374#pragma unroll
375                for (int l = 0; l < VKQ_ratio; ++l) {
376                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
377                }
378                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
379            }
380        }
381
382        __syncthreads();
383    }
384
385    // Apply attention sinks
386    if (sinksf && blockIdx.y == 0) {
387        const float sinkf = sinksf[head];
388        const half  sinkh = __float2half(sinkf);
389
390#pragma unroll
391        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
392            const int j = j0 + threadIdx.y;
393
394            if (std::is_same<KQ_acc_t, float>::value) {
395                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
396
397                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
398                KQ_max_f[j0/nwarps] = kqmax_new;
399
400                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
401
402                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
403#pragma unroll
404                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
405                    const int i = i0 + threadIdx.x;
406                    if (i0 + warp_size > D/2 && i >= D/2) break;
407                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
408                }
409            } else {
410                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
411                half kqmax_new = fmaxf(kqmax_old, sinkh);
412                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
413
414                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
415                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
416
417                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
418                const half val = hexp(sinkh - kqmax_new);
419                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
420
421#pragma unroll
422                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
423                    const int i = i0 + threadIdx.x;
424                    if (i0 + warp_size > D/2 && i >= D/2) break;
425                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
426                }
427            }
428        }
429
430        __syncthreads();
431    }
432#pragma unroll
433    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
434        const int j_VKQ = j0 + threadIdx.y;
435        if (ic0 + j_VKQ >= int(ne01.z)) {
436            return;
437        }
438
439        float KQ_rowsum_j;
440        if (std::is_same<KQ_acc_t, float>::value) {
441            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
442        } else {
443            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
444        }
445
446        const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
447
448#pragma unroll
449        for (int i0 = 0; i0 < D; i0 += warp_size) {
450            const int i = i0 + threadIdx.x;
451            if (i0 + warp_size > D && i >= D) {
452                break;
453            }
454            float dst_val = VKQ[j_VKQ*D_padded + i];
455            if (gridDim.y == 1) {
456                dst_val /= KQ_rowsum_j;
457            }
458            dst[j_dst_unrolled*D + i] = dst_val;
459        }
460
461        if (gridDim.y == 1 || threadIdx.x != 0) {
462            continue;
463        }
464
465        float2 dst_meta_val;
466        if (std::is_same<KQ_acc_t, float>::value) {
467            dst_meta_val.x = KQ_max_f[j0/nwarps];
468        } else {
469            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
470        }
471        dst_meta_val.y = KQ_rowsum_j;
472        dst_meta[j_dst_unrolled] = dst_meta_val;
473    }
474#else
475    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
476        max_bias, m0, m1, n_head_log2, logit_softcap,
477        ne00, ne01, ne02, ne03,
478              nb01, nb02, nb03,
479        ne10, ne11, ne12, ne13,
480              nb11, nb12, nb13,
481              nb21, nb22, nb23,
482              ne31, ne32, ne33,
483              nb31, nb32, nb33);
484    NO_DEVICE_CODE;
485#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
486}
487
// Largest power of 2 that divides x, e.g. 12 -> 4, 7 -> 1, 64 -> 64.
// x == 0 is special-cased to 1: 0 is even and 0/2 == 0, so the plain recursion would never
// terminate (a compile error in constant evaluation, a stack overflow at runtime).
constexpr int get_max_power_of_2(int x) {
    return x != 0 && x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}

static_assert(get_max_power_of_2(0) == 1, "Test failed.");
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
496
// Number of VKQ rows calculated in parallel.
// D is split into D/frag_m tiles of frag_m rows. The number of tiles handled at once is the
// largest power of 2 dividing D/frag_m, capped at nwarps. The result is that tile count times frag_m.
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return frag_m * (nwarps <= get_max_power_of_2(D/frag_m) ? nwarps : get_max_power_of_2(D/frag_m));
}

static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
511
// Instantiates and launches flash_attn_ext_f16 with the following configuration:
//   - head size D,
//   - cols_per_block Q columns per CUDA block,
//   - KQ accumulator type KQ_acc_t,
//   - 4 warps per block.
// The logit softcap is read from dst->op_params[2]. Exactly 0.0f selects the kernel variant without
// softcapping; any other value (including NaN) selects the softcapping variant.
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int nwarps     = 4;
    // 32-row WMMA fragments are only used for 8-column blocks whose head size is divisible by 32:
    constexpr int frag_m     = (cols_per_block == 8 && D % 32 == 0) ? 32 : 16;
    constexpr int VKQ_stride = get_VKQ_stride(D, nwarps, frag_m);

    const ggml_tensor * KQV = dst;

    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    const fattn_kernel_t fattn_kernel = logit_softcap != 0.0f
        ? &flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, true>
        : &flash_attn_ext_f16<D, cols_per_block, nwarps, VKQ_stride, KQ_acc_t, false>;

    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
536
// Host-side dispatcher for the WMMA FlashAttention kernel.
//
// It picks the number of Q columns per CUDA block (cols_per_block) and the KQ accumulator type,
// then selects the instantiation for the head size Q->ne[0]:
//   - prec != GGML_PREC_DEFAULT: float accumulators. Use 32 columns when Q->ne[1] > 32 and the
//     head size is <= 128; otherwise use 16 columns.
//   - Default precision: half accumulators. Use 8 columns (non-HIP builds only) when Q->ne[1] <= 8
//     and the head size is a multiple of the warp size. Otherwise use 16 columns when
//     Q->ne[1] <= 32, else 32 columns.
// Head sizes without an instantiation abort via GGML_ABORT.
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;

    const auto head_size = Q->ne[0];
    const auto n_cols_Q  = Q->ne[1];

    if (prec != GGML_PREC_DEFAULT) {
        if (n_cols_Q > 32 && head_size <= 128) {
            constexpr int cols_per_block = 32;
            // Head sizes above 128 never reach this branch, so there is no 256 case here.
            switch (head_size) {
                case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); return;
                case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); return;
                case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); return;
                case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); return;
                case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); return;
                default:  GGML_ABORT("fatal error");
            }
            return;
        }

        constexpr int cols_per_block = 16;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); return;
            case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); return;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); return;
            case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); return;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); return;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); return;
            default:  GGML_ABORT("fatal error");
        }
        return;
    }

#if !defined(GGML_USE_HIP)
    if (n_cols_Q <= 8 && head_size % warp_size == 0) {
        constexpr int cols_per_block = 8;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); return;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); return;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); return;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); return;
            default:  GGML_ABORT("fatal error");
        }
        return;
    }
#endif // !defined(GGML_USE_HIP)

    if (n_cols_Q <= 32) {
        constexpr int cols_per_block = 16;
        switch (head_size) {
            case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); return;
            case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); return;
            case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); return;
            case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); return;
            case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); return;
            case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); return;
            default:  GGML_ABORT("fatal error");
        }
        return;
    }

    constexpr int cols_per_block = 32;
    switch (head_size) {
        case  64: ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); return;
        case  80: ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); return;
        case  96: ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); return;
        case 112: ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); return;
        case 128: ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); return;
        case 256: ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); return;
        default:  GGML_ABORT("fatal error");
    }
}