1#include "common.cuh"
2#include "cp-async.cuh"
3#include "mma.cuh"
4#include "fattn-common.cuh"
5
6using namespace ggml_cuda_mma;
7
8// Config options for the MMA kernel.
9// Should not affect results, only speed/register pressure/shared memory use.
10struct fattn_mma_config {
11 int nthreads; // Number of threads per CUDA block.
12 int occupancy; // Targeted occupancy for the MMA kernel.
13 int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
14 int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel.
15 int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel.
16 int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
17 int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
18 bool Q_in_reg; // Whether the Q values should be kept permanently in registers.
19
20 constexpr __host__ __device__ fattn_mma_config(
21 int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
22 nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
23 nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
24};
25
26#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
27 if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
28 static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
29 static_assert( (occupancy_) <= 8, "bad occupancy"); \
30 static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
31 static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
32 static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
33 static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
34 static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
35 return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
36 } \
37
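// Column order of the config table entries below, matching the macro parameters:
//   DKQ, DV, ncols, nthreads, occupancy, nbatch_fa, nbatch_K2, nbatch_V2, nbatch_combine, nstages_target, Q_in_reg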
38static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
39 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
40 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
41 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
42 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
43
44 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
45 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
46 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
47 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
48
49 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
50 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
51 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
52 GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
53
54 GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
55 GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
56 GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
57 GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
58
59 GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
60 GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
61 GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
62 GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
63
64 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
65 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
66 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
67 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
68
69 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
70 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
71 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
72 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
73
74 return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
75}
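
// Example (reading the table above, illustrative only):
//   constexpr fattn_mma_config cfg = ggml_cuda_fattn_mma_get_config_ampere(128, 128, 32);
//   // -> cfg.nthreads == 128, cfg.occupancy == 2, cfg.nbatch_fa == 64,
//   //    cfg.nbatch_K2 == cfg.nbatch_V2 == cfg.nbatch_combine == 64,
//   //    cfg.nstages_target == 2, cfg.Q_in_reg == true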
76
77static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
78 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
79 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
80 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
81 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
82
83 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
84 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
85 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
86 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
87
88 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
89}
90
91static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
92 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
93 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
94 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
95 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
96
97 // TODO tune specifically for Volta
98 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
99}
100
101static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
102 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
103 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
104 GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
105
106 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
107 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
108 GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
109
110 // TODO tune specifically for RDNA
111 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
112}
113
114static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
115 if (ampere_mma_available(cc)) {
116 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
117 }
118 if (turing_mma_available(cc)) {
119 return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
120 }
121 if (amd_wmma_available(cc)) {
122 return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
123 }
124 GGML_ASSERT(volta_mma_available(cc));
125 return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
126}
127
128static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
129#if defined(AMPERE_MMA_AVAILABLE)
130 return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
131#elif defined(TURING_MMA_AVAILABLE)
132 return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
133#elif defined(VOLTA_MMA_AVAILABLE)
134 return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
135#elif defined(AMD_WMMA_AVAILABLE)
136 return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
137#else
138 GGML_UNUSED_VARS(DKQ, DV, ncols);
139 return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
140#endif // defined(AMPERE_MMA_AVAILABLE)
141}
142
143static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
144 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
145}
146
147static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
148 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
149}
150
151static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
152 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
153}
154
155static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
156 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
157}
158
159static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
160 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
161}
162
163static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
164 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
165}
166
167static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
168 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
169}
170
171static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
172 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
173}
174
175static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
176 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
177}
178
179static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
180 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
181}
182
183static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
184 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
185}
186
187static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
188 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
189}
190
191static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
192 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
193}
194
195static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
196 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
197}
198
199static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
200 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
201}
202
203static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
204 return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
205}
206
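// cols_per_thread is the number of KQ columns for which a single thread keeps its own
// KQ_max/KQ_rowsum bookkeeping (see the arrays of this size in the functions further down).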
207static constexpr __device__ int get_cols_per_thread() {
208#if defined(AMD_WMMA_AVAILABLE)
    return 1; // On RDNA each thread handles a single KQ column.
210#else
211 return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
212#endif // defined(AMD_WMMA_AVAILABLE)
213}
214
215static __host__ int get_cols_per_warp(const int cc) {
216 if (turing_mma_available(cc) || amd_wmma_available(cc)) {
217 return 16;
218 } else {
219 // Volta
220 return 32;
221 }
222}
223
224// ------------------------------------------------------------------------------------------------------------------
225
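// Returned value: 0 == plain synchronous loads (no cp.async), 1 == loads go through cp.async but are
// waited on immediately (no overlap between data loading and computation), >1 == the K tile and mask
// for the next iteration are preloaded asynchronously while the current tile is being processed.
// For example, on Ampere with DKQ == DV == 128 and ncols2 >= 2 this resolves to 2; without cp.async
// support or with ncols2 == 1 it resolves to 0.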
226static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
227 return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
228}
229
230static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
231#ifdef CP_ASYNC_AVAILABLE
232 return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
233#else
234 GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
235 return 0;
236#endif // CP_ASYNC_AVAILABLE
237}
238
239// ------------------------------------------------------------------------------------------------------------------
240
241template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
242static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
243 const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
244 // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
245 // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
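    // Worked example for the synchronous path with D2 == 40 half2 per row (e.g. DKQ == 80 with nbatch_K2 == 40):
    //   stride_k == 32: k0_start =  0, k0_stop = 32 -> bulk of the row, 1 row per warp per step
    //   stride_k == 16: k0_start = 32, k0_stop = 32 -> skipped
    //   stride_k ==  8: k0_start = 32, k0_stop = 40 -> remaining 8 half2, 4 rows per warp per step
    //   stride_k ==  4: k0_start = 40, k0_stop = 40 -> skipped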
246 if constexpr (use_cp_async) {
247 static_assert(!oob_check, "OOB check not compatible with cp_async");
248 constexpr int preload = 64;
249 constexpr int h2_per_chunk = 16/sizeof(half2);
250 const int chunks_per_row = D2 / h2_per_chunk;
251
252 const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
253
254 auto load = [&] __device__ (auto n) {
255 const int stride_k = WARP_SIZE >> n;
256 const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
257 const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
258 const int stride_i = WARP_SIZE / stride_k;
259
260 if (k0_start == k0_stop) {
261 return;
262 }
263
264#pragma unroll
265 for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
266 const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
267
268 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
269 break;
270 }
271
272#pragma unroll
273 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
274 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
275
276 cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
277 }
278 }
279 };
280 // 1: max 32*16=512 bytes, 256 half
281 // 2: max 16*16=256 bytes, 128 half
282 // 3: max 8*16=128 bytes, 64 half
283 // 4: max 4*16= 64 bytes, 32 half
284 // 5: max 2*16= 32 bytes, 16 half
285 // 6: max 1*16= 16 bytes, 8 half
286 ggml_cuda_unroll<6>{}(load);
287 } else {
288 // TODO use ggml_cuda_memcpy_1
289 auto load = [&] __device__ (const int n) {
290 const int stride_k = WARP_SIZE >> n;
291 const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
292 const int k0_stop = D2 - D2 % (1*stride_k);
293 const int stride_i = WARP_SIZE / stride_k;
294
295 if (k0_start == k0_stop) {
296 return;
297 }
298
299#pragma unroll
300 for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
301 const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
302
303 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
304 break;
305 }
306
307#pragma unroll
308 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
309 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
310
311 tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
312 }
313 }
314 };
315 // 1: max 32* 4=128 bytes, 64 half
316 // 2: max 16* 4= 64 bytes, 32 half
317 // 3: max 8* 4= 32 bytes, 16 half
318 // 4: max 4* 4= 16 bytes, 8 half
319 ggml_cuda_unroll<4>{}(load);
320 }
321}
322
323template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
324static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
325 const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
326 const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
327 if constexpr (use_cp_async) {
328 static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
329 static_assert(!oob_check, "OOB check incompatible with cp_async");
330 constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
331 constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
332 constexpr int stride_j = nwarps * cols_per_warp;
333
334 const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
335
336#pragma unroll
337 for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
338 const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
339 const int j_vram = fastmodulo(j0 + j_sram, ne01);
340
341 if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
342 break;
343 }
344
345 const int i = 8 * (threadIdx.x % (nbatch_fa/8));
346
347 cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
348 }
349 } else if constexpr (oob_check) {
350#pragma unroll
351 for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
352 const int j_sram = j1 + threadIdx.y;
353 const int j_vram = fastmodulo(j0 + j_sram, ne01);
354
355 if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
356 break;
357 }
358
359#pragma unroll
360 for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
361 const int i = i0 + threadIdx.x;
362
363 tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
364 }
365 }
366 } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
367 constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
368 constexpr int stride_j = nwarps * cols_per_warp;
369#pragma unroll
370 for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
371 const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
372 const int j_vram = fastmodulo(j0 + j_sram, ne01);
373
374 if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
375 break;
376 }
377
378 const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
379
380 ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
381 }
382 } else {
383#pragma unroll
384 for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
385 const int j_sram = j1 + threadIdx.y;
386 const int j_vram = fastmodulo(j0 + j_sram, ne01);
387
388 if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
389 break;
390 }
391
392#pragma unroll
393 for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
394 const int i = i0 + 2*threadIdx.x;
395
396 ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
397 }
398 }
399 }
400}
401
402template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
403 bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
404 typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
405static __device__ __forceinline__ void flash_attn_ext_f16_iter(
406 const float2 * const __restrict__ Q_f2,
407 const half2 * const __restrict__ K_h2,
408 const half2 * const __restrict__ V_h2,
409 const half * const __restrict__ mask_h,
410 float2 * const __restrict__ dstk,
411 float2 * const __restrict__ dstk_fixup,
412 const float scale,
413 const float slope,
414 const float logit_softcap,
415 const uint3 ne01,
416 const int ne02,
417 const int stride_K,
418 const int stride_V,
419 const int stride_mask,
420 half2 * const __restrict__ tile_Q,
421 half2 * const __restrict__ tile_K,
422 half2 * const __restrict__ tile_V,
423 half * const __restrict__ tile_mask,
424 T_B_KQ * const __restrict__ Q_B,
425 T_C_VKQ * const __restrict__ VKQ_C,
426 float * const __restrict__ KQ_max,
427 float * const __restrict__ KQ_rowsum,
428 const int jt,
429 const int kb0,
430 const int k_VKQ_sup) {
431#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
432 constexpr int ncols = ncols1 * ncols2;
433 constexpr int cols_per_warp = T_B_KQ::I;
434 constexpr int cols_per_thread = get_cols_per_thread();
435 constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
436 constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
437 constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
438 constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
439 constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
440 constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
441
442 constexpr int stride_tile_Q = DKQ/2 + 4;
443 constexpr int stride_tile_K = nbatch_K2 + 4;
444
445 constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
446
447 const int k_VKQ_0 = kb0 * nbatch_fa;
448#if defined(TURING_MMA_AVAILABLE)
449 T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
450#elif defined(AMD_WMMA_AVAILABLE)
451 T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
452#else // Volta
453 T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
454#endif // defined(TURING_MMA_AVAILABLE)
455
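    // Rough schedule for nstages > 1: the K tile and mask for this iteration have already been
    // preloaded, so the asynchronous copy of the V tile issued below overlaps with the KQ matrix
    // multiplication, and the copy of the next K tile/mask issued further down overlaps with the
    // VKQ matrix multiplication.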
456 if constexpr (nstages > 1) {
457 static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage loading");
460 constexpr bool use_cp_async = true;
461 cp_async_wait_all();
462 __syncthreads();
463 flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
464 (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
465 } else {
466 constexpr bool use_cp_async = nstages == 1;
467 if (ncols2 > 1 || mask_h) {
468 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
469 (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
470 }
471 }
472
473 // For MLA K and V have the same data.
474 // Therefore, iterate over K in reverse and later re-use the data if possible.
475#pragma unroll
476 for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
477 const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
478 const int k0_diff = k0_stop - k0_start;
479
480 if constexpr (nstages <= 1) {
481 constexpr bool use_cp_async = nstages == 1;
482 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
483 (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
484 if (use_cp_async) {
485 cp_async_wait_all();
486 }
487 __syncthreads();
488 }
489
490 // Calculate tile of KQ:
491 if constexpr (Q_in_reg) {
492#pragma unroll
493 for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
494 const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
495#pragma unroll
496 for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
497 T_A_KQ K_A;
498 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
499 if constexpr (cols_per_warp == 8) {
500 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
501 } else {
502 // Wide version of KQ_C is column-major
503#if defined(AMD_WMMA_AVAILABLE)
504 // RDNA matrix C is column-major.
505 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
506#else
507 // swap A and B for CUDA.
508 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
509#endif // defined(AMD_WMMA_AVAILABLE)
510 }
511 }
512 }
513 } else {
514#pragma unroll
515 for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
516 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
517
518#pragma unroll
519 for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
520 const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
521
522 T_A_KQ K_A;
523 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
524
525 if constexpr (cols_per_warp == 8) {
526 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
527 } else {
528 // Wide version of KQ_C is column-major
529#if defined(AMD_WMMA_AVAILABLE)
530 // RDNA matrix C is column-major.
531 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
532#else
533 // swap A and B for CUDA.
534 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
535#endif // defined(AMD_WMMA_AVAILABLE)
536 }
537 }
538 }
539 }
540
541 if constexpr (nstages <= 1) {
542 __syncthreads(); // Only needed if tile_K == tile_V.
543 }
544 }
545
546 if (use_logit_softcap) {
547 constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
548 static_assert(nbatch_fa % stride == 0, "bad loop size");
549#pragma unroll
550 for (int i = 0; i < nbatch_fa/stride; ++i) {
551#pragma unroll
552 for (int l = 0; l < T_C_KQ::ne; ++l) {
553 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
554 }
555 }
556 }
557
558 float KQ_max_new[cols_per_thread];
559#pragma unroll
560 for (int col = 0; col < cols_per_thread; ++col) {
561 KQ_max_new[col] = KQ_max[col];
562 }
563 float KQ_rowsum_add[cols_per_thread] = {0.0f};
564
565 if constexpr (cols_per_warp == 8) {
566 if (ncols2 > 1 || mask_h) {
567#pragma unroll
568 for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
569 const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
570#pragma unroll
571 for (int l = 0; l < T_C_KQ::ne; ++l) {
572 const int i = i0 + T_C_KQ::get_i(l);
573 const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
574
575 KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
576 }
577 }
578 }
579
580 // Calculate softmax for each KQ column using the current max. value.
581 // The divisor is stored in KQ_rowsum and will be applied at the end.
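        // Online softmax bookkeeping (sketch, ignoring FATTN_KQ_MAX_OFFSET):
        //   KQ_max_new = max(KQ_max, max_k KQ_k)
        //   p_k        = expf(KQ_k - KQ_max_new)
        //   KQ_rowsum  = expf(KQ_max - KQ_max_new)*KQ_rowsum + sum_k p_k
        // The VKQ accumulators are rescaled by the same expf(KQ_max - KQ_max_new) factor further down.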
582 static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
583#pragma unroll
584 for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
585#pragma unroll
586 for (int l = 0; l < T_C_KQ::ne; ++l) {
587 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
588#if defined(AMD_WMMA_AVAILABLE)
589 constexpr int KQ_idx = 0;
590#else
591 // Turing + Volta:
592 const int KQ_idx = l % 2;
593#endif // defined(AMD_WMMA_AVAILABLE)
594 KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
595 }
596 }
597 }
598
599 // Values per KQ column are spread across 8 threads:
600#pragma unroll
601 for (int col = 0; col < cols_per_thread; ++col) {
602#pragma unroll
603 for (int offset = 16; offset >= 4; offset >>= 1) {
604 KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
605 }
606 }
607
608 static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
609#pragma unroll
610 for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
611#pragma unroll
612 for (int l = 0; l < T_C_KQ::ne; ++l) {
613 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
614#if defined(AMD_WMMA_AVAILABLE)
615 constexpr int KQ_idx = 0;
616#else
617 // Turing + Volta:
618 const int KQ_idx = l % 2;
619#endif // defined(AMD_WMMA_AVAILABLE)
620 KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
621 KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
622 } else {
623 KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
624 }
625 }
626 }
    } else { // cols_per_warp != 8: wide KQ_C tiles in column-major layout (Turing wide tiles, AMD, Volta).
628 if (ncols2 > 1 || mask_h) {
629#pragma unroll
630 for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
631 const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
632#pragma unroll
633 for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
634 const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
635 const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
636
637 const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
638 KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
639 KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
640 }
641 }
642 }
643
644 // Calculate softmax for each KQ column using the current max. value.
645 // The divisor is stored in KQ_rowsum and will be applied at the end.
646 static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
647#pragma unroll
648 for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
649#pragma unroll
650 for (int l = 0; l < T_C_KQ::ne; ++l) {
651 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
652#if defined(AMD_WMMA_AVAILABLE)
653 constexpr int KQ_idx = 0;
654#else
655 // Turing + Volta:
656 const int KQ_idx = (l/2) % 2;
657#endif // defined(AMD_WMMA_AVAILABLE)
658 KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
659 }
660 }
661 }
662
663#pragma unroll
664 for (int col = 0; col < cols_per_thread; ++col) {
665#if defined(TURING_MMA_AVAILABLE)
666 // Values per KQ column are spread across 4 threads:
667 constexpr int offset_first = 2;
668 constexpr int offset_last = 1;
669#elif defined(AMD_WMMA_AVAILABLE)
670 // Values per KQ column are spread across 2 threads:
671 constexpr int offset_first = 16;
672 constexpr int offset_last = 16;
673#else // Volta
674 // Values per KQ column are spread across 2 threads:
675 constexpr int offset_first = 2;
676 constexpr int offset_last = 2;
677#endif // defined(TURING_MMA_AVAILABLE)
678#pragma unroll
679 for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
680 KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
681 }
682 }
683
684 static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
685#pragma unroll
686 for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
687#pragma unroll
688 for (int l = 0; l < T_C_KQ::ne; ++l) {
689 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
690#if defined(AMD_WMMA_AVAILABLE)
691 constexpr int KQ_idx = 0;
692#else
693 // Turing + Volta:
694 const int KQ_idx = (l/2) % 2;
695#endif // defined(AMD_WMMA_AVAILABLE)
696 KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
697 KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
698 } else {
699 KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
700 }
701 }
702 }
703 }
704
705 {
706 float KQ_max_scale[cols_per_thread];
707#pragma unroll
708 for (int col = 0; col < cols_per_thread; ++col) {
709 const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
710 KQ_max_scale[col] = expf(KQ_max_diff);
711 KQ_max[col] = KQ_max_new[col];
712
            // Zero the scale via its bit pattern if the exponent argument is below the flush-to-zero threshold:
            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
714
715 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
716 KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
717 }
718
719#if defined(TURING_MMA_AVAILABLE)
720 if constexpr (cols_per_warp == 8) {
721 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
722#pragma unroll
723 for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
724#pragma unroll
725 for (int l = 0; l < T_C_VKQ::ne; ++l) {
726 VKQ_C[i].x[l] *= KQ_max_scale_h2;
727 }
728 }
729 } else {
730#pragma unroll
731 for (int col = 0; col < cols_per_thread; ++col) {
732 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
733#pragma unroll
734 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
735#pragma unroll
736 for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
737 VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
738 }
739 }
740 }
741 }
742#elif defined(AMD_WMMA_AVAILABLE)
743 const half2 KQ_max_scale_h2 = make_half2(
744 KQ_max_scale[0], KQ_max_scale[0]);
745#pragma unroll
746 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
747#pragma unroll
748 for (int l = 0; l < T_C_VKQ::ne; ++l) {
749 VKQ_C[i].x[l] *= KQ_max_scale_h2;
750 }
751 }
752#else // Volta
753 const half2 KQ_max_scale_h2 = make_half2(
754 KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
755#pragma unroll
756 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
757#pragma unroll
758 for (int l = 0; l < T_C_VKQ::ne; ++l) {
759 VKQ_C[i].x[l] *= KQ_max_scale_h2;
760 }
761 }
762#endif // defined(TURING_MMA_AVAILABLE)
763 }
764
765 // Convert KQ C tiles into B tiles for VKQ calculation:
766 T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
767 static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
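    // With 8 columns per warp KQ_C is row-major and needs an explicit transposition to become a
    // column-major B operand; the wide KQ_C tiles are already column-major and can be used directly.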
768 if constexpr (cols_per_warp == 8) {
769#pragma unroll
770 for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
771 B[k] = get_transposed(get_half2(KQ_C[k]));
772 }
773 } else {
774 for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
775 B[k] = get_half2(KQ_C[k]);
776 }
777 }
778
779 if constexpr (nstages > 1) {
        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
781 // Preload K tile for next iteration:
782 constexpr bool use_cp_async = true;
783 cp_async_wait_all();
784 __syncthreads();
785 if (!last_iter) {
786 if (ncols2 > 1 || mask_h) {
787 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
788 (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
789 }
790 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
791 (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
792 }
793 }
794
795
796#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
797 T_A_VKQ A_identity;
798 make_identity_mat(A_identity);
799#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
800
801 // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
802#pragma unroll
803 for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
804 static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
805 const int i0_stop = i0_start + 2*nbatch_V2;
806 const int i0_diff = i0_stop - i0_start;
807
808 if constexpr (nstages <= 1) {
809 if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
810 constexpr bool use_cp_async = nstages == 1;
811 flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
812 (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
813 if (use_cp_async) {
814 cp_async_wait_all();
815 }
816 __syncthreads();
817 }
818 }
819 const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
820
821#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
822 constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
823#pragma unroll
824 for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
825 static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
826#pragma unroll
827 for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
828 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
829
830 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
831#if defined(LDMATRIX_TRANS_AVAILABLE)
832 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
833#else
834 // TODO: Try to transpose tile_V when loading gmem to smem.
835 // Use mma to transpose T_A_VKQ for RDNA.
836 T_A_VKQ A_trans;
837 load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
838 mma(A, A_trans, A_identity);
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
840 if constexpr (T_B_KQ::I == 8) {
841 mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
842 } else {
843 // Wide version of VKQ_C is column-major.
844#if defined(AMD_WMMA_AVAILABLE)
845 // RDNA matrix C is column-major.
846 mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
847#else
848 // swap A and B for CUDA.
849 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
850#endif // defined(AMD_WMMA_AVAILABLE)
851 }
852 }
853 }
854#else // Volta
855 constexpr int i0_stride = 2*T_C_VKQ::J;
856#pragma unroll
857 for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
858 static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
859 static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
860#pragma unroll
861 for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
862 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
863
864 T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
865 load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
866 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
867 }
868 }
869#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
870
871 if constexpr (nstages <= 1) {
872 __syncthreads(); // Only needed if tile_K == tile_V.
873 }
874 }
875#else
876 GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
877 scale, slope, logit_softcap, ne01, ne02,
878 stride_K, stride_V, stride_mask,
879 tile_Q, tile_K, tile_V, tile_mask,
880 Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
881 NO_DEVICE_CODE;
882#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
883}
884
885#if defined(TURING_MMA_AVAILABLE)
886template<int ncols> struct mma_tile_sizes {
887 using T_A_KQ = tile<16, 8, half2>; // row-major
888 using T_B_KQ = tile<16, 8, half2>; // column-major
889 using T_C_KQ = tile<16, 16, float>; // column-major
890 using T_A_VKQ = tile<16, 8, half2>; // row-major
891 using T_B_VKQ = tile<16, 8, half2>; // column-major
892 using T_C_VKQ = tile<16, 8, half2>; // column-major
893};
894template<> struct mma_tile_sizes<8> {
895 using T_A_KQ = tile<16, 8, half2>; // row-major
896 using T_B_KQ = tile< 8, 8, half2>; // column-major
897 using T_C_KQ = tile<16, 8, float>; // row-major
898 using T_A_VKQ = tile<16, 8, half2>; // row-major
899 using T_B_VKQ = tile< 8, 8, half2>; // column-major
900 using T_C_VKQ = tile<16, 4, half2>; // row-major
901};
902#elif defined(AMD_WMMA_AVAILABLE)
903template<int ncols> struct mma_tile_sizes {
904 using T_A_KQ = tile<16, 8, half2>; // row-major
905 using T_B_KQ = tile<16, 8, half2>; // column-major
906 using T_C_KQ = tile<16, 16, float>; // column-major
907 using T_A_VKQ = tile<16, 8, half2>; // row-major
908 using T_B_VKQ = tile<16, 8, half2>; // column-major
909 using T_C_VKQ = tile<16, 8, half2>; // column-major
910};
911#else // Volta
912template<int ncols> struct mma_tile_sizes {
913 using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
914 using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
915 using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
916 using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
917 using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
918 using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
919};
920#endif // defined(TURING_MMA_AVAILABLE)
921
922template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
923static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
924 const float2 * const __restrict__ Q_f2,
925 const half2 * const __restrict__ K_h2,
926 const half2 * const __restrict__ V_h2,
927 const half * const __restrict__ mask_h,
928 const float * const __restrict__ sinks_f,
929 float2 * const __restrict__ dstk,
930 float2 * const __restrict__ dstk_fixup,
931 const float scale,
932 const float slope,
933 const float logit_softcap,
934 const uint3 ne01,
935 const int ne02,
936 const int gqa_ratio,
937 const int ne11,
938 const int stride_Q1,
939 const int stride_Q2,
940 const int stride_K,
941 const int stride_V,
942 const int stride_mask,
943 const int jt,
944 const int zt_gqa,
945 const int kb0_start,
946 const int kb0_stop) {
947#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
949
950 constexpr int ncols = ncols1 * ncols2;
951 using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
952 using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
953 using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
954 using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
955 using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
956 using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
957
958 constexpr int cols_per_warp = T_B_KQ::I;
959 constexpr int cols_per_thread = get_cols_per_thread();
960 constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
961 constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
962 constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
963 constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
964 constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
965 constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
966 constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
967
968 if (cols_per_warp > ncols) {
969 NO_DEVICE_CODE;
970 return;
971 }
972
973 static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
974
975 constexpr int stride_tile_Q = DKQ/2 + 4;
976 constexpr int stride_tile_K = nbatch_K2 + 4;
977
978 constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
979 constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
980
981 extern __shared__ half2 tile_Q[];
982 half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
983 half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
984 half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
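    // Shared memory layout (sketch): tile_K aliases tile_Q if Q is kept in registers since Q only
    // needs to be staged in shared memory temporarily, and tile_V aliases tile_K unless a second
    // pipeline stage requires K and V to be resident at the same time. tile_mask follows the K/V tiles.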
985
986 T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
987#if defined(TURING_MMA_AVAILABLE)
988 T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
989#elif defined(AMD_WMMA_AVAILABLE)
990 T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
991#else // Volta
992 T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
993#endif // defined(TURING_MMA_AVAILABLE)
994
995 float KQ_rowsum[cols_per_thread] = {0.0f};
996 float KQ_max[cols_per_thread];
997#pragma unroll
998 for (int col = 0; col < cols_per_thread; ++col) {
999 KQ_max[col] = -FLT_MAX/2.0f;
1000 }
1001
1002 // Load Q data into tile_Q, either temporarily or permanently.
1003 // Q in registers is faster, but register pressure is the biggest bottleneck.
1004 // The loading is done with decreasing granularity for D for better memory bandwidth.
1005 const half2 scale_h2 = make_half2(scale, scale);
1006#pragma unroll
1007 for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1008 const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
1009 const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
1010 const int stride_jc = WARP_SIZE / stride_k;
1011
1012 if (k0_start == k0_stop) {
1013 continue;
1014 }
1015
1016#pragma unroll
1017 for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
1018 const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1019
1020 if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
1021 break;
1022 }
1023
1024 const int j = jc / ncols2;
1025 const int c = jc % ncols2;
1026
1027 if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
1028#pragma unroll
1029 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1030 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1031
1032 const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
1033 tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
1034 }
1035 } else {
1036#pragma unroll
1037 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1038 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1039
1040 tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
1041 }
1042 }
1043 }
1044 }
1045
1046 __syncthreads();
1047
1048 if (Q_in_reg) {
1049 const int j0 = (threadIdx.y / np) * cols_per_warp;
1050
1051#pragma unroll
1052 for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
1053 load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
1054 }
1055 }
1056
1057 __syncthreads();
1058
1059 int kb0 = kb0_start;
1060
1061 // Preload mask and K data for first iteration when using cp_async with multiple stages:
1062 if constexpr (nstages > 1) {
1063 static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
1064 constexpr bool use_cp_async = true;
1065 constexpr bool oob_check = false;
1066 constexpr int k_VKQ_sup = nbatch_fa;
1067 if (ncols2 > 1 || mask_h) {
1068 flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
1069 (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
1070 }
1071 flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
1072 (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
1073 }
1074
1075 // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
1076 if constexpr (ncols2 == 1) {
1077 constexpr bool oob_check = true;
1078 for (; kb0 < kb0_stop-1; ++kb0) {
1079 constexpr bool last_iter = false;
1080 constexpr int k_VKQ_sup = nbatch_fa;
1081 flash_attn_ext_f16_iter
1082 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
1083 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
1084 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
1085 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
1086 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
1087 }
1088 constexpr bool last_iter = true;
1089 const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
1090 flash_attn_ext_f16_iter
1091 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
1092 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
1093 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
1094 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
1095 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
1096 } else {
1097 constexpr bool oob_check = false;
1098 for (; kb0 < kb0_stop-1; ++kb0) {
1099 constexpr bool last_iter = false;
1100 constexpr int k_VKQ_sup = nbatch_fa;
1101 flash_attn_ext_f16_iter
1102 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
1103 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
1104 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
1105 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
1106 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
1107 }
1108 constexpr bool last_iter = true;
1109 constexpr int k_VKQ_sup = nbatch_fa;
1110 flash_attn_ext_f16_iter
1111 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
1112 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
1113 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
1114 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
1115 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
1116 }
1117
    // With multi-stage loading there is no __syncthreads at the end of the iteration,
    // so there can be a race condition on shared memory access for combining/writing back results.
1120 if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
1121 __syncthreads();
1122 }
1123
1124 // Finally, sum up partial KQ rowsums.
1125 {
1126#if defined(TURING_MMA_AVAILABLE)
1127 // The partial sums are spread across 8/4 threads.
1128 constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
1129 constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
1130#elif defined(AMD_WMMA_AVAILABLE)
1131 // The partial sums are spread across 2 threads.
1132 constexpr int offset_first = 16;
1133 constexpr int offset_last = 16;
1134#else // Volta
1135 // The partial sums are spread across 2 threads.
1136 constexpr int offset_first = 2;
1137 constexpr int offset_last = 2;
1138#endif // defined(TURING_MMA_AVAILABLE)
1139#pragma unroll
1140 for (int col = 0; col < cols_per_thread; ++col) {
1141#pragma unroll
1142 for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
1143 KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
1144 }
1145 }
1146 }
1147
    // If attention sinks are used, potentially re-scale if KQ_max is small.
    // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum,
    // so it is done unconditionally by every thread.
1151 if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
1152 float KQ_max_scale[cols_per_thread];
1153#pragma unroll
1154 for (int col = 0; col < cols_per_thread; ++col) {
1155 const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
1156 const float sink = sinks_f[jc % ncols2];
1157
1158 const float KQ_max_new = fmaxf(KQ_max[col], sink);
1159 const float KQ_max_diff = KQ_max[col] - KQ_max_new;
1160 KQ_max_scale[col] = expf(KQ_max_diff);
1161 KQ_max[col] = KQ_max_new;
1162
1163 *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
1164
1165 const float KQ_max_add = expf(sink - KQ_max_new);
1166 KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
1167 }
1168
1169#if defined(TURING_MMA_AVAILABLE)
1170 if constexpr (cols_per_warp == 8) {
1171 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
1172#pragma unroll
1173 for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
1174#pragma unroll
1175 for (int l = 0; l < T_C_VKQ::ne; ++l) {
1176 VKQ_C[i].x[l] *= KQ_max_scale_h2;
1177 }
1178 }
1179 } else {
1180#pragma unroll
1181 for (int col = 0; col < cols_per_thread; ++col) {
1182 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
1183#pragma unroll
1184 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1185#pragma unroll
1186 for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
1187 VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
1188 }
1189 }
1190 }
1191 }
1192#elif defined(AMD_WMMA_AVAILABLE)
1193 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
1194#pragma unroll
1195 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1196#pragma unroll
1197 for (int l = 0; l < T_C_VKQ::ne; ++l) {
1198 VKQ_C[i].x[l] *= KQ_max_scale_h2;
1199 }
1200 }
1201#else // Volta
1202 const int col = (threadIdx.x / 2) % 2;
1203 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
1204#pragma unroll
1205 for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1206#pragma unroll
1207 for (int l = 0; l < T_C_VKQ::ne; ++l) {
1208 VKQ_C[i].x[l] *= KQ_max_scale_h2;
1209 }
1210 }
1211#endif // defined(TURING_MMA_AVAILABLE)
1212 }
1213
    // Combine VKQ accumulator values if np > 1.
    // It's also faster to do small writes to shared memory followed by a large write to VRAM
    // than to do many small writes to VRAM directly.
    // So the VKQ accumulators are also written to shared memory in column-major format even when np == 1.
1217
1218 constexpr int tile_stride = nbatch_combine + 4;
1219 static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
1220
1221 if constexpr (cols_per_warp == 8) {
1222 const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
1223 const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
1224 const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
1225
1226 if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
1227 // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1228 ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1229 }
1230
1231 __syncthreads();
1232
1233 if (np == 1) {
1234 // No combination is needed, the meta data can be directly written from registers to VRAM.
1235 if (needs_fixup && threadIdx.x < T_B_KQ::I) {
1236 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1237 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1238 }
1239 if (is_fixup && threadIdx.x < T_B_KQ::I) {
1240 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1241 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1242 }
1243 }
1244 } else {
1245 // jc_cwm = jc combine write meta
1246 // KQ_cmr = KQ combine max rowsum
1247 // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
1248#if defined(TURING_MMA_AVAILABLE)
1249 const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
1250 const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
1251 const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
1252#elif defined(AMD_WMMA_AVAILABLE)
1253 const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
1254 const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
1255 const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
1256#else // Volta
1257 const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
1258 const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
1259 const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
1260#endif // defined(TURING_MMA_AVAILABLE)
1261
1262 if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
1263 ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1264 }
1265
1266 __syncthreads();
1267
1268 if (np == 1) {
1269 // No combination is needed, the meta data can be directly written from registers to VRAM.
1270 if (needs_fixup && thread_should_write) {
1271 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1272 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1273 }
1274 if (is_fixup && thread_should_write) {
1275 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1276 dstk_fixup_meta[jc_cwm] = KQ_cmr;
1277 }
1278 }
1279 }
1280
1281 if (np > 1 && threadIdx.y % np == 0) {
1282 // Combine the meta data for parallel warps via shared memory.
1283 // Warps with threadIdx.y % np != 0 must NOT return early.
1284 // All threads must return simultaneously to avoid race conditions with work on the next tile.
1285
1286 constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
1287
1288 const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1289 float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
1290 float2 meta[nmeta];
1291#pragma unroll
1292 for (int imeta = 0; imeta < nmeta; ++imeta) {
1293 meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
1294 }
1295
1296 float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
1297#pragma unroll
1298 for (int imeta = 1; imeta < nmeta; ++imeta) {
1299 KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
1300 }
1301#pragma unroll
1302 for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1303 if (offset < WARP_SIZE) {
1304 KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
1305 }
1306 }
1307
1308 float KQ_cms[nmeta]; // KQ combine max scale per warp.
1309#pragma unroll
1310 for (int imeta = 0; imeta < nmeta; ++imeta) {
1311 KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
1312 }
1313
1314 float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
1315#pragma unroll
1316 for (int imeta = 1; imeta < nmeta; ++imeta) {
1317 KQ_crs += KQ_cms[imeta]*meta[imeta].y;
1318 }
1319#pragma unroll
1320 for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1321 if (offset < WARP_SIZE) {
1322 KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
1323 }
1324 }
1325
1326 __syncthreads();
1327
1328 // Write back combined meta data:
1329#pragma unroll
1330 for (int imeta = 0; imeta < nmeta; ++imeta) {
1331 if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
1332 // Combined KQ max scale + rowsum.
1333 meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1334 }
1335 }
1336
1337 // Combined KQ max + rowsum.
1338 static_assert(cols_per_warp <= WARP_SIZE);
1339 if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1340 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1341 dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1342 }
1343 if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1344 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1345 dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1346 }
1347 } else if (np > 1) {
1348 // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
1349 // Therefore, all other warps also need to execute a __syncthreads().
1350 // Otherwise the points at which warps synchronize with each other would become misaligned.
1351 __syncthreads();
1352 }
1353
1354#pragma unroll
1355 for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
1356 if constexpr (cols_per_warp == 8) {
1357 const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
1358#pragma unroll
1359 for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
1360 const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
1361
1362#pragma unroll
1363 for (int l = 0; l < T_B_KQ::ne; ++l) {
1364 const int k = k1 + T_B_KQ::get_j(l);
1365
1366 tile_Q[jc_cwd*tile_stride + k] = B.x[l];
1367 }
1368 }
1369 } else {
1370 const int j0 = threadIdx.y*cols_per_warp;
1371#pragma unroll
1372 for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1373#pragma unroll
1374 for (int l = 0; l < T_C_VKQ::ne; ++l) {
1375 const int j = j0 + T_C_VKQ::get_i(l);
1376 const int k = k1 + T_C_VKQ::get_j(l);
1377
1378 tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
1379 }
1380 }
1381 }
1382
1383 __syncthreads();
1384
1385 if (np == 1 || threadIdx.y % np == 0) {
1386 // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
1387 // The values after that are for the partial results of the individual blocks.
1388 float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
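            // Resulting layout of dstk_fixup in float2 units, as implied by the offsets used here and for
            // the meta data above: gridDim.x*ncols entries written by blocks that need a fixup, then
            // gridDim.x*ncols entries written by fixup blocks, then ncols*DV/2 partial results per block.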
1389
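            // The loop below copies the bulk of nbatch_combine with full-warp-wide contiguous accesses
            // (stride_k == WARP_SIZE) and the remainder with narrower strides, in which case one warp covers
            // stride_jc = WARP_SIZE/stride_k output rows at a time.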
1390#pragma unroll
1391 for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1392 const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1393 const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
1394 const int stride_jc = WARP_SIZE / stride_k;
1395
1396 if (k0_start == k0_stop) {
1397 continue;
1398 }
1399
1400#pragma unroll
1401 for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
1402 const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1403
1404 if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
1405 break;
1406 }
1407
1408 const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
1409
1410 const int j_dst = jc_dst / ncols2;
1411 const int c_dst = jc_dst % ncols2;
1412
1413 if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
1414 continue;
1415 }
1416
1417 const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
1418#pragma unroll
1419 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1420 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1421
1422 float2 dstk_val = make_float2(0.0f, 0.0f);
1423#pragma unroll
1424 for (int ip = 0; ip < np; ++ip) {
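                        // meta_j[... + 0] is the per-warp combine scale (the .x component of the meta data
                        // written in the combine step above); meta_j[1], read below, is the KQ rowsum used
                        // for the final normalization.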
1425 const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
1426 const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
1427 dstk_val.x += dstk_val_add.x*KQ_crs;
1428 dstk_val.y += dstk_val_add.y*KQ_crs;
1429 }
1430
1431 if (!needs_fixup && !is_fixup) {
1432 const float KQ_rowsum_j = meta_j[1];
1433 dstk_val.x /= KQ_rowsum_j;
1434 dstk_val.y /= KQ_rowsum_j;
1435 }
1436
1437 if (is_fixup) {
1438 dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
1439 } else {
1440 dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
1441 }
1442 }
1443 }
1444 }
1445 }
1446 if (np > 1) {
1447 __syncthreads();
1448 }
1449 }
1450#else
1451 GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1452 scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
1453 stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1454 jt, kb0_start, kb0_stop);
1455 NO_DEVICE_CODE;
1456#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
1457}
1458
1459template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
1460__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
1461static __global__ void flash_attn_ext_f16(
1462 const char * __restrict__ Q,
1463 const char * __restrict__ K,
1464 const char * __restrict__ V,
1465 const char * __restrict__ mask,
1466 const char * __restrict__ sinks,
1467 const int * __restrict__ KV_max,
1468 float * __restrict__ dst,
1469 float2 * __restrict__ dst_meta,
1470 const float scale,
1471 const float max_bias,
1472 const float m0,
1473 const float m1,
1474 const uint32_t n_head_log2,
1475 const float logit_softcap,
1476 const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
1477 const int32_t nb01, const int32_t nb02, const int32_t nb03,
1478 const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1479 const int32_t nb11, const int32_t nb12, const int64_t nb13,
1480 const int32_t nb21, const int32_t nb22, const int64_t nb23,
1481 const int32_t ne31, const int32_t ne32, const int32_t ne33,
1482 const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1483#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
1484
1485 // Skip unused kernel variants for faster compilation:
1486 if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1487 NO_DEVICE_CODE;
1488 return;
1489 }
1490#ifdef VOLTA_MMA_AVAILABLE
1491 if (ncols1*ncols2 < 32) {
1492 NO_DEVICE_CODE;
1493 return;
1494 }
1495#endif // VOLTA_MMA_AVAILABLE
1496
1497#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1498 if (ncols1*ncols2 > 32) {
1499 NO_DEVICE_CODE;
1500 return;
1501 }
1502#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1503
1504#if defined(AMD_WMMA_AVAILABLE)
1505 if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
1506 NO_DEVICE_CODE;
1507 return;
1508 }
1509#endif // defined(AMD_WMMA_AVAILABLE)
1510
1511 constexpr int ncols = ncols1 * ncols2;
1512 constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
1513 constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
1514 constexpr int nwarps = nthreads / WARP_SIZE;
1515
1516    const int gqa_ratio = ne02 / ne12; // With grouped-query attention there is more than one Q matrix per K/V matrix.
1517
1518 const int stride_Q1 = nb01 / sizeof(float2);
1519 const int stride_Q2 = nb02 / sizeof(float2);
1520 const int stride_K = nb11 / sizeof(half2);
1521 const int stride_mask = nb31 / sizeof(half);
1522
1523 const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
1524
1525 const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1526 const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1527 const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
1528
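    // Work distribution: the iteration space of size iter_k*iter_j*iter_z_gqa*ne12*ne03 is split into
    // gridDim.x contiguous, (nearly) equally sized chunks, one per CUDA block. A chunk may start and/or
    // end in the middle of an output tile; combining such partial tiles is what the fixup logic handles.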
1529 // kbc == k block continuous, current index in continuous ijk space.
1530 int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1531 const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1532
1533    // If the seams of 2 CUDA blocks fall within an output tile, their results need to be combined.
1534    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
1535    // In the most general case, more than 2 seams can fall into the same tile.
1536
1537 // kb0 == k start index when in the output tile.
1538 int kb0_start = kbc % iter_k;
1539 int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1540
1541 while (kbc < kbc_stop && kb0_stop == iter_k) {
1542        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1543 const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1544 const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1545 const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1546 const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
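        // The decomposition above inverts the flattening
        //     kbc = (((sequence*ne12 + z_KV)*iter_z_gqa + zt_gqa)*iter_j + jt)*iter_k + kb0.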
1547
1548 const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1549
1550 const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1551 const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1552 const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1553 (const half *) (mask + nb33*(sequence % ne33));
1554 float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1555
1556 const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1557 const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1558
1559 const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1560
1561 if (KV_max) {
1562 kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1563 }
1564        constexpr bool is_fixup = false; // All but (potentially) the last iteration write their data to dst rather than to the fixup buffer.
1565 if (kb0_start == 0) {
1566 constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1567 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1568 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1569 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1570 } else {
1571 constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
1572 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1573 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1574 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1575 }
1576
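        // Advance kbc to the first index of the next output tile, i.e. to the next multiple of iter_k.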
1577 kbc += iter_k;
1578 kbc -= kbc % iter_k;
1579
1580 kb0_start = 0;
1581 kb0_stop = min(iter_k, kbc_stop - kbc);
1582 }
1583
1584 if (kbc >= kbc_stop) {
1585 return;
1586 }
1587
1588 // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1589 const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1590 const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1591 const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1592 const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1593
1594 const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1595
1596 const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1597 const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1598 const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1599 (const half *) (mask + nb33*(sequence % ne33));
1600 float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1601
1602 const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1603 const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1604
1605 const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1606
1607 if (KV_max) {
1608 kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1609 }
1610
1611    constexpr bool is_fixup = true; // The final, partially processed tile writes its data to the fixup buffer to avoid data races with other blocks.
1612 constexpr bool needs_fixup = false;
1613 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1614 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1615 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1616#else
1617 GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1618 max_bias, m0, m1, n_head_log2, logit_softcap,
1619 ne00, ne01, ne02, ne03,
1620 nb01, nb02, nb03,
1621 ne10, ne11, ne12, ne13,
1622 nb11, nb12, nb13,
1623 nb21, nb22, nb23,
1624 ne31, ne32, ne33,
1625 nb31, nb32, nb33);
1626 NO_DEVICE_CODE;
1627#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
1628}
1629
1630template <int DKQ, int DV, int ncols1, int ncols2>
1631void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1632 const ggml_tensor * KQV = dst;
1633 const int id = ggml_cuda_get_device();
1634 const int cc = ggml_cuda_info().devices[id].cc;
1635
1636 constexpr int ncols = ncols1 * ncols2;
1637
1638 const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
1639 const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc);
1640 const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc);
1641 const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc);
1642 const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
1643 const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
1644 const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
1645
1646 const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
1647 const int nwarps = nthreads / WARP_SIZE;
1648
1649 constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
1650
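    // The "+ 4" terms below are per-row padding in half2 units (16 bytes); in the combine step this padding
    // is also where the per-column meta data is stored (see flash_attn_ext_f16_process_tile).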
1651 const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1652 const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1653 const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1654 const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
1655 const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1656
1657 const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1658
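    // If Q is kept in registers, the Q tile and the K/V + mask tiles do not need to reside in shared memory
    // at the same time, so only the larger of the two is allocated; otherwise their sum is needed. The
    // combine buffer reuses the same allocation, hence the outer max.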
1659 const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
1660 std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
1661 nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
1662
1663 float logit_softcap;
1664 memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1665
1666#if defined(GGML_USE_HIP)
1667 using fattn_kernel_ptr_t = const void*;
1668#else
1669 using fattn_kernel_ptr_t = fattn_kernel_t;
1670#endif // defined(GGML_USE_HIP)
1671 fattn_kernel_t fattn_kernel;
1672 if (logit_softcap == 0.0f) {
1673 constexpr bool use_logit_softcap = false;
1674 fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1675
1676#if !defined(GGML_USE_MUSA)
1677 static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1678 if (!shared_memory_limit_raised[id]) {
1679 CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1680 shared_memory_limit_raised[id] = true;
1681 }
1682#endif // !defined(GGML_USE_MUSA)
1683 } else {
1684 constexpr bool use_logit_softcap = true;
1685 fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1686
1687#if !defined(GGML_USE_MUSA)
1688 static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1689 if (!shared_memory_limit_raised[id]) {
1690 CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1691 shared_memory_limit_raised[id] = true;
1692 }
1693#endif // !defined(GGML_USE_MUSA)
1694 }
1695
1696 launch_fattn<DV, ncols1, ncols2>
1697 (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
1698}
1699
1700
1701#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
1702 template void ggml_cuda_flash_attn_ext_mma_f16_case \
1703 <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1704
1705#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \
1706 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \
1707 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \
1708 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \
1709 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \
1710 extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
1711
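// Extern declarations for all supported head sizes and (ncols1, ncols2) splits; the matching explicit
// template instantiations are expected to be provided by separate translation units.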
1712DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8)
1713DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8)
1714DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8)
1715DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8)
1716DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8)
1717DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8)
1718
1719DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16)
1720DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16)
1721DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16)
1722DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
1723DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
1724DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
1725
1726DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32)
1727DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32)
1728DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32)
1729DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
1730DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
1731DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
1732
1733DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64)
1734DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64)
1735DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64)
1736DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
1737DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
1738DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
1739
1740// The number of viable configurations for DeepSeek is very limited:
1741extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1742extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1743extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
1744
1745// For GLM 4.7 Flash
1746extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
1747extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
1748extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
1749extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
1750extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);