1#include "common.cuh"
  2#include "fattn-common.cuh"
  3
  4static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
  5    return 128;
  6    GGML_UNUSED(cc);
  7}
  8
  9static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
 10    return 128;
 11}
 12
// Currently LLVM with the amdgcn target does not support unrolling loops
// that contain a break that cannot be resolved at compile time,
// so the resulting -Wpass-failed diagnostics are silenced for the kernel below.
 15#ifdef __clang__
 16#pragma clang diagnostic push
 17#pragma clang diagnostic ignored "-Wpass-failed"
 18#endif // __clang__
// Flash attention kernel, "vec" variant.
//
// Work decomposition as used by the code below:
//   - blockIdx.x selects a group of `ncols` Q columns starting at ic0 = blockIdx.x*ncols.
//   - blockIdx.z encodes sequence*ne02 + head.
//   - blockIdx.y selects the starting chunk of K/V rows; the main loop then advances in
//     steps of gridDim.y*nthreads rows.
//   - threadIdx.x is the lane within a warp, threadIdx.y the warp index
//     (tid = WARP_SIZE*threadIdx.y + threadIdx.x, assumed to be < 128).
//
// Template parameters:
//   D                 - head size; must be divisible by 2*WARP_SIZE (static_assert below).
//   ncols             - number of Q columns processed per block.
//   type_K, type_V    - ggml types of K and V. A K type other than GGML_TYPE_F16 makes the
//                       kernel convert Q to q8_1 before computing the dot products.
//   use_logit_softcap - if true, each KQ value is replaced by logit_softcap*tanhf(value).
//                       Only compiled for D == 128 or D == 256.
//
// Pointer arguments are advanced with the nb* values while still typed as char pointers,
// i.e. the nb* arguments are used as byte strides. ne01.z is used as the number of valid
// Q columns for bounds checks.
//
// Output:
//   - gridDim.y == 1: dst receives the result already divided by the softmax denominator.
//   - gridDim.y != 1: dst receives unnormalized partial results and dst_meta receives the
//     (KQ max, KQ sum) pair per column and block.
//     NOTE(review): presumably merged by a separate combine step that is not visible here.
//
// Shared memory: static only (KQ, KQ_max_shared, KQ_sum_shared).
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
static __global__ void flash_attn_ext_vec(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
            max_bias, m0, m1, n_head_log2, logit_softcap,
            ne00, ne01, ne02, ne03,
                  nb01, nb02, nb03,
            ne10, ne11, ne12, ne13,
                  nb11, nb12, nb13,
                  nb21, nb22, nb23,
                  ne31, ne32, ne33,
                  nb31, nb32, nb33);
        NO_DEVICE_CODE;
        return;
    }

    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    // Maximum number of bytes per copy as reported by the helper, and the same expressed in 4-byte elements.
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    // Number of threads that cooperate on one K row / one V row when the data is quantized.
    // HIP builds use a fixed small group for the KQ dot product (2 on RDNA, 4 otherwise).
#ifdef GGML_USE_HIP
#ifdef RDNA
    constexpr int nthreads_KQ_q = 2;
#else
    constexpr int nthreads_KQ_q = 4;
#endif // RDNA
    constexpr int nthreads_V_q  = (D/4 < 32 ? D/4 : 32);
#else
    constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
    constexpr int nthreads_V_q  = (D/4 < 32 ? D/4 : 32);
#endif // GGML_USE_HIP

    // For F16 data the group size is derived from the copy width instead of the quantized defaults.
    constexpr int nthreads    = ggml_cuda_fattn_vec_get_nthreads_device();
    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;

    static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
    static_assert(WARP_SIZE % nthreads_V  == 0, "bad nthreads_V");

    // V_rows_per_thread: number of V values one thread dequantizes per call.
    // V_cols_per_iter:   number of K/V positions a warp handles concurrently in the V accumulation loop.
    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
    constexpr int V_cols_per_iter   = WARP_SIZE / nthreads_V;

    // Function pointers selected at compile time for the given K/V types.
    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef V_DOT2_F32_F16_AVAILABLE
    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half,  V_rows_per_thread>();
#else
    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // V_DOT2_F32_F16_AVAILABLE

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

    // blockIdx.z enumerates (sequence, head) pairs with head as the fast index.
    const int sequence = blockIdx.z / ne02;
    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    Q += nb03*sequence + nb02* head              + nb01*ic0;
    K += nb13*sequence + nb12*(head / gqa_ratio);
    V += nb23*sequence + nb22*(head / gqa_ratio);

    // The mask is indexed with sequence % ne33; it is only dereferenced when `mask` is non-null.
    const half * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);

    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
    constexpr int nwarps = nthreads / WARP_SIZE;
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
    __builtin_assume(tid < nthreads);

    // The shared KQ buffer is reused for three purposes (temporary storage for the Q conversion,
    // the per-iteration KQ values, and the final combination of the VKQ accumulators),
    // so it is sized for the larger of ne_KQ and ne_combine.
    constexpr int ne_KQ      = ncols*D;
    constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef V_DOT2_F32_F16_AVAILABLE
    half2            VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
    __shared__ half   KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
    float2           VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
    __shared__ float  KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // V_DOT2_F32_F16_AVAILABLE

    // Per-thread running softmax state (maximum and sum of exponentials) for each column.
    float KQ_max[ncols];
    float KQ_sum[ncols];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        KQ_max[j] = -FLT_MAX/2.0f;
        KQ_sum[j] = 0.0f;
    }

    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef V_DOT2_F32_F16_AVAILABLE
    half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // V_DOT2_F32_F16_AVAILABLE
    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
    float2  Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
    if constexpr (Q_q8_1) {
        // Quantized K: each warp converts the column j = j0 + threadIdx.y via shared memory.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            // Warps without a column in the last (partial) pass have nothing to do.
            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            // Reuse KQ as temporary storage for converting Q to q8_1:
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

            // Set memory to zero if out of bounds:
            if (ncols > 1 && ic0 + j >= int(ne01.z)) {
#pragma unroll
                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                    const int i = i0 + threadIdx.x;

                    if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
                        tmp_q_i32[i] = 0;
                    }
                }
                if (threadIdx.x < D/QK8_1) {
                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
                }
            } else {
                // Q is read as float; `scale` is handed to the quantization helper.
                const float * Q_f = (const float *) (Q + j*nb01);
                constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
#pragma unroll
                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
                    quantize_q8_1_to_shared<float2, nthreads_quantize>
                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
                }
            }
        }

        // Make the converted Q visible to all warps before it is copied into registers.
        __syncthreads();

        // Every thread loads its slice of every column from shared memory into registers.
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            int    * tmp_q_i32 = (int    *) &KQ[j*D];
            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

#pragma unroll
            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);

                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];
            }
        }

        // All reads of the temporary Q data must finish before KQ is overwritten in the main loop.
        __syncthreads();
    } else {
        // F16 K: Q is loaded as float2 directly from global memory and multiplied by `scale`.
        // Columns beyond ne01.z are not loaded and stay zero.
#ifdef V_DOT2_F32_F16_AVAILABLE
        const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;

                // Staging buffer in float precision; converted to half2 below.
                __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                    ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
                    ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
                }
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne; ++i1) {
                    Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
                }
            }
#pragma unroll
            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
                Q_reg[j][k] *= scale_h2;
            }
        }
#else
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
                    ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
                }
            }
#pragma unroll
            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
                Q_reg[j][k].x *= scale;
                Q_reg[j][k].y *= scale;
            }
        }
#endif // V_DOT2_F32_F16_AVAILABLE
    }

    // Upper bound for the K/V positions to visit: taken from KV_max if provided, otherwise ne11.
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    // Offset the pointers to the first chunk of nthreads K/V positions owned by this block.
    K     += blockIdx.y*nthreads * nb11;
    V     += blockIdx.y*nthreads * nb21;
    maskh += blockIdx.y*nthreads;
    for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
             // Increment pointers after each loop:
             K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {

        // Calculate KQ tile and keep track of new maximum KQ values:
        float KQ_reg[ncols]; // KQ in registers.

        float KQ_max_new[ncols];
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            KQ_max_new[j] = KQ_max[j];
        }

        // Each group of nthreads_KQ threads computes nthreads_KQ dot products, one per iteration.
        // The thread whose index within the group equals i_KQ_0 keeps the result,
        // so after the loop every thread holds the KQ value for position `tid` of the current chunk.
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
            const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;

#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
                sum = warp_reduce_sum<nthreads_KQ>(sum);

                if (use_logit_softcap) {
                    sum = logit_softcap*tanhf(sum);
                }

                // The mask value is scaled by the ALiBi slope; out-of-bounds columns skip the mask read.
                if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
                    sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
                }

                KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);

                if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
                    KQ_reg[j] = sum;
                }
            }
        }

#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            // Combine the maxima of the different thread groups within the warp.
#pragma unroll
            for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
                KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
            }
            // Rescale factor for values accumulated relative to the previous maximum.
            const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];

            KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
            // Publish the exponentiated value so that other lanes of this warp can read it below.
            KQ[j*nthreads + tid] = KQ_reg[j];

#ifdef V_DOT2_F32_F16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
            }
#endif // V_DOT2_F32_F16_AVAILABLE
        }

        // The KQ entries read below (index k) were written by lanes of the same warp,
        // so a warp-level barrier is used here rather than a block-level one.
#ifndef GGML_USE_HIP
        __syncwarp();
#endif // GGML_USE_HIP

        // Accumulate V*softmax(KQ): each warp walks over its WARP_SIZE positions,
        // V_cols_per_iter positions at a time.
#pragma unroll
        for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
            const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);

#ifdef V_DOT2_F32_F16_AVAILABLE
            half2 KQ_k[ncols];
#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
            }
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
                half2 tmp[V_rows_per_thread/2];
                dequantize_V(V + k*nb21, tmp,
                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
                    for (int j = 0; j < ncols; ++j) {
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
                    }
                }
            }
#else
            float KQ_k[ncols];
#pragma unroll
            for (int j = 0; j < ncols; ++j) {
                KQ_k[j] = KQ[j*nthreads + k];
            }
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
                float2 tmp[V_rows_per_thread/2];
                dequantize_V(V + k*nb21, tmp,
                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
                    for (int j = 0; j < ncols; ++j) {
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
                    }
                }
            }
#endif // V_DOT2_F32_F16_AVAILABLE
        }
    }

    // Attention sinks: the per-head sink value takes part in the softmax maximum and denominator
    // but contributes nothing to VKQ. Only the blocks with blockIdx.y == 0 apply it.
    if (sinks && blockIdx.y == 0) {
        const float sink = ((const float *) sinks)[head];

        // The warp with index threadIdx.y handles column j = j0 + threadIdx.y.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
            const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
            KQ_max[j] = kqmax_new_j;

            // Only lane 0 adds the sink term because KQ_sum is summed over the warp later on.
            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);

#ifdef V_DOT2_F32_F16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
                VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
            }
#endif // V_DOT2_F32_F16_AVAILABLE
        }
    }

    // Scratch space for combining the per-warp maxima and sums across the block.
    // Warp 0 fills all WARP_SIZE slots with neutral values so that the slots without a
    // corresponding warp do not affect the warp-wide reductions below.
    __shared__ float KQ_max_shared[ncols][WARP_SIZE];
    __shared__ float KQ_sum_shared[ncols][WARP_SIZE];
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (threadIdx.y == 0) {
            KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
            KQ_sum_shared[j][threadIdx.x] = 0.0f;
        }
    }

    __syncthreads();

    // Lane 0 of each warp publishes the warp's maximum in the slot indexed by the warp.
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (threadIdx.x == 0) {
            KQ_max_shared[j][threadIdx.y] = KQ_max[j];
        }
    }
    __syncthreads();

    // Combine the partial results of all warps, one column at a time.
#pragma unroll
    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
        // The condition does not depend on the thread index, so all threads of the block leave
        // together and the barriers inside the loop stay uniform.
        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
            break;
        }

        // Block-wide maximum and the factor to rescale this thread's data to it.
        float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
        kqmax_new = warp_reduce_max(kqmax_new);
        const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
        KQ_max[j_VKQ] = kqmax_new;

        // Rescale the accumulators and store them in the shared KQ buffer,
        // using one slice of D values per (warp, V column) pair.
#ifdef V_DOT2_F32_F16_AVAILABLE
        half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
            + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);

        const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
        }
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);

            ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
        }
#else
        float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
            + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);

#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
        }
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
            const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);

            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,                       &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
            ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
        }
#endif // V_DOT2_F32_F16_AVAILABLE

        // Rescale and reduce the sums: first within each warp, then across warps via shared memory.
        KQ_sum[j_VKQ] *= kqmax_scale;
        KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
        if (threadIdx.x == 0) {
            KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
        }

        // The VKQ slices and per-warp sums must be complete before they are read.
        __syncthreads();

        if (nthreads <= D || tid < D) {
            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
            KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);

            // Each participating thread produces the output elements i0 + tid by summing the
            // corresponding entries of all nwarps*V_cols_per_iter slices.
#pragma unroll
            for (int i0 = 0; i0 < D; i0 += nthreads) {
                float dst_val = 0;
#pragma unroll
                for (int w = 0; w < nwarps; ++w) {
#pragma unroll
                    for (int v = 0; v < V_cols_per_iter; ++v) {
                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
                    }
                }
                // Normalize only if this block covered all K/V positions on its own.
                if (gridDim.y == 1) {
                    dst_val /= KQ_sum[j_VKQ];
                }
                dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
            }
        }

        // The KQ buffer is overwritten by the next column, so wait until all reads are done.
        if (j_VKQ < ncols-1) {
            __syncthreads();
        }

    }

    // With more than one block along y, store the per-block (max, sum) pair for each valid column.
    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
        dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
    }
#else
    // Flash attention is not available for this build target: emit no device code.
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
              nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
              nb11, nb12, nb13,
              nb21, nb22, nb23,
              ne31, ne32, ne33,
              nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
508#ifdef __clang__
509#pragma clang diagnostic pop
510#endif // __clang__
511
// Launches flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap> for `dst`
// through launch_fattn, using the block size reported for the current device.
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // Compute capability of the active device; forwarded to the block-size query.
    const int device_cc       = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int threads_per_blk = ggml_cuda_fattn_vec_get_nthreads_host(device_cc);
    const int warps_per_blk   = threads_per_blk / WARP_SIZE;

    // Flags telling the launcher which of K/V this instantiation expects as F16.
    const bool K_is_f16 = (type_K == GGML_TYPE_F16);
    const bool V_is_f16 = (type_V == GGML_TYPE_F16);

    // No dynamic shared memory is requested.
    constexpr size_t dynamic_smem_bytes = 0;

    fattn_kernel_t kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;

    // NOTE(review): the positional arguments `D` and `false` are interpreted by launch_fattn,
    // whose definition is not visible here - confirm their meaning there.
    launch_fattn<D, cols_per_block, 1>(
        ctx, dst, kernel, warps_per_blk, dynamic_smem_bytes, D, K_is_f16, V_is_f16, false);
}
524
// Runtime dispatcher for the vec kernel: turns two runtime properties of the operation into
// template arguments of ggml_cuda_flash_attn_ext_vec_case_impl.
//   - cols_per_block is 1 when Q->ne[1] == 1 and 2 otherwise.
//   - use_logit_softcap is false exactly when the stored logit softcap compares equal to 0.0f.
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // The logit softcap is read from op_params, viewed as floats, at index 2.
    float logit_softcap;
    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    const bool softcap_is_zero = logit_softcap == 0.0f;
    const bool single_q_column = dst->src[0]->ne[1] == 1;

    if (single_q_column) {
        if (softcap_is_zero) {
            ggml_cuda_flash_attn_ext_vec_case_impl<D, 1, type_K, type_V, false>(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_case_impl<D, 1, type_K, type_V, true>(ctx, dst);
        }
    } else {
        if (softcap_is_zero) {
            ggml_cuda_flash_attn_ext_vec_case_impl<D, 2, type_K, type_V, false>(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_case_impl<D, 2, type_K, type_V, true>(ctx, dst);
        }
    }
}
554
// Expands to an explicit instantiation of ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>
// without a trailing semicolon; the user of the macro supplies it.
// Prefixed with `extern` (see below) the expansion is an explicit instantiation declaration.
// NOTE: the last line of the definition ends with a line continuation, so the blank line
// that follows is part of the macro and must be kept.
#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \
    template void ggml_cuda_flash_attn_ext_vec_case                         \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

// Declares the extern instantiations for one head size D and one K type,
// combined with each of the six V types F16, Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0.
// NOTE: as above, the trailing line continuation makes the following blank line part of the macro.
// NOTE(review): the matching instantiation definitions are presumably provided by other
// translation units - not visible here.
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \

// Head size 64, all six K types:
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)

// Head size 128, all six K types:
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)

// Head size 256, all six K types:
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)