#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"

// nbatch_fa == number of KQ rows to process per iteration
// nbatch_K  == number of K columns to load in parallel for KQ calculation
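// Example from the tables below (an illustration, not an additional constraint): for DKQ == DV == 128
// with ncols == 16 on RDNA the config uses nbatch_fa == 128 and nbatch_K == 128, i.e. each main-loop
// iteration processes 128 KV tokens and K is loaded in chunks of 128 columns.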

// TODO optimize kernel parameters for FP16 NVIDIA (P100)
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112

// The ROCm compiler cannot handle templating in __launch_bounds__.
// As a workaround, define a macro to package the kernel parameters as uint32_t:
#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \
        static_assert((nthreads)  <= 512, "bad nthreads");                                            \
        static_assert((occupancy) <=   8, "bad occupancy");                                           \
        static_assert((nbatch_fa) <= 256, "bad nbatch_fa");                                           \
        static_assert((nbatch_K)  <= 256, "bad nbatch_K");                                            \
        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \
    }                                                                                                 \
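
// Example of the packing above (illustration only): GGML_CUDA_FATTN_TILE_CONFIG_CASE(64, 64, 8, 256, 2, 64, 64)
// returns 256 | (2 << 10) | (64 << 14) | (64 << 23): bits [0, 10) hold nthreads, bits [10, 14) the occupancy,
// bits [14, 23) nbatch_fa, and bits [23, 32) nbatch_K.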
static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)

    return 0;
}

static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA(cc)) {
            return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
        }
        return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
    }
    if (fast_fp16_available(cc)) {
        return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
    }
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
}

static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
#ifdef GGML_USE_HIP
#ifdef RDNA
    return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
    return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
}

static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >>  0) & ((1 << 10) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >>  0) & ((1 << 10) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
}
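
// Example of the decoding (illustration only): on NVIDIA with fast FP16, DKQ == DV == 128 and ncols == 8
// pack to 256 | (2 << 10) | (64 << 14) | (64 << 23), so the getters above return nthreads == 256,
// occupancy == 2, nbatch_fa == 64, and nbatch_K == 64.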
// TODO: deduplicate with mma-f16
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;

                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                    ggml_cuda_memcpy_1<cpy_nb>(
                        tile_KV + i*(J/2 + J_padding) + j,
                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
                }
            }
        }
    };
    // 1: max 64*16=1024 bytes, 512 half
    // 2: max 32*16= 512 bytes, 256 half
    // 3: max 16*16= 256 bytes, 128 half
    // 4: max  8*16= 128 bytes,  64 half
    // 5: max  4*16=  64 bytes,  32 half
    // 6: max  2*16=  32 bytes,  16 half
    // 7: max  1*16=  16 bytes,   8 half
    static_assert(J % 8 == 0, "bad J");
    static_assert((J/2) % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<7>{}(load);
}

template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);

                    const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
                    __align__(16) half2 tmp_h2[cpy_ne/2];
                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);

                    __align__(16) float2 tmp_f2[cpy_ne/2];
#pragma unroll
                    for (int l = 0; l < cpy_ne/2; ++l) {
                        tmp_f2[l] = __half22float2(tmp_h2[l]);
                    }
                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
                }
            }
        }
    };
    // 1: max 32*16=512 bytes, 128 float
    // 2: max 16*16=256 bytes,  64 float
    // 3: max  8*16=128 bytes,  32 float
    // 4: max  4*16= 64 bytes,  16 float
    // 5: max  2*16= 32 bytes,   8 float
    static_assert(J % 8 == 0, "bad J");
    static_assert(J % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<5>{}(load);
}
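
// Note on usage (taken from the call sites below): the loaders above are instantiated with
// I == nbatch_fa, J == nbatch_K, J_padding == cpy_ne for K tiles and with
// I == nbatch_V, J == DV, J_padding == 0 for V tiles.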

// Function that performs a single iteration over nbatch_K columns for the KQ matrix multiplication:
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
          bool use_logit_softcap, bool oob_check, typename T_vec_dot>
static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
        T_vec_dot   * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        T_vec_dot   * const KV_tmp,
        const int stride_K2,
        const int k_VKQ_0,
        const int k_VKQ_sup,
        const int k_KQ_0,
        float * KQ_acc) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
    __syncthreads();

#ifdef FAST_FP16_AVAILABLE
    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        __align__(16) half2 Q_k[cpw][cpy_ne];
#else
    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        __align__(16) float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y / np)*cpw;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
#pragma unroll
            for (int jc0 = 0; jc0 < cpw; ++jc0) {
#pragma unroll
                for (int k = 0; k < cpy_ne; ++k) {
                    ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
                }
            }
        }
    }

    if (k_KQ_0 + nbatch_K < DKQ) {
        __syncthreads(); // Sync not needed on last iteration.
    }
}

// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
          bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
static __device__ __forceinline__ void flash_attn_tile_iter(
        T_vec_dot   * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        const half2 * const __restrict__ V_h2,
        const half  * const __restrict__ mask,
        const uint3 ne01,
        const float logit_softcap,
        const float slope,
        T_KQ        * const KQ,
        T_vec_dot   * const KV_tmp,
        const int stride_K2,
        const int stride_V2,
        const int stride_mask,
        float * const KQ_max,
        float * const KQ_sum,
        T_acc * const VKQ,
        const int k_VKQ_0,
        const int k_VKQ_max,
        const int col_Q_0) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
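    // For example (illustration only): with warp_size == 32, DV == 40 pads to DVp == 64 and DV == 112 pads to DVp == 128.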

    // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
    // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][nbatch_fa][KQ_cs].
#ifdef FAST_FP16_AVAILABLE
    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
#else
    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
#endif // FAST_FP16_AVAILABLE
    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
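    // Illustration of that layout (assuming KQ_cs == 4): the KQ value for column jc and token i lives at
    // KQ[(jc/4)*(nbatch_fa*4) + i*4 + jc%4], i.e. the values of 4 consecutive columns for one token are contiguous.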
    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data

    float KQ_max_new[cpw];
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_max_new[jc0] = KQ_max[jc0];
    }

    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.

    // KQ = K @ Q matrix multiplication:
    constexpr int nbatch_K_last = DKQ % nbatch_K;
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }
    if (nbatch_K_last > 0) {
        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }

    // Apply logit softcap + mask, update KQ_max:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
            // Therefore, scale down Q values and apply the inverse scale to the FP32 KQ values afterwards again.
            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)

            if (use_logit_softcap) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
            }

            if (!oob_check || i_KQ < k_VKQ_sup) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;

                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
            }
        }

        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
    }

    if constexpr (np == 1) {
        __syncthreads();
    } else {
        static_assert(cpw == 1, "bad cpw");
        __shared__ float KQ_max_new_shared[nwarps];
        if (threadIdx.x == 0) {
            KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
        }
        __syncthreads();
        KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
    }
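
    // For example (illustration only): with nwarps == 8 and ncols == 4 (np == 2, cpw == 1) warps 2*c and 2*c+1
    // cooperate on Q column c; each warp publishes its partial maximum to shared memory and the
    // warp_reduce_max over np lanes above combines the two partials.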

    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
#ifdef FAST_FP16_AVAILABLE
        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#else
        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
            const int jc = jc0 + jc1;

            const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);
            KQ_max[jc] = KQ_max_new[jc];

            float KQ_sum_add = 0.0f;
#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
                KQ_sum_add += val;
                tmp[i0/(np*warp_size)][jc1] = val;
            }
            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
            const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;

            ggml_cuda_memcpy_1<sizeof(tmp[0])>(
                KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
                tmp[i0/(np*warp_size)]);
        }
    }

    // VKQ = V @ KQ matrix multiplication:
    static_assert(DV <= DKQ, "bad DV");
    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
    static_assert(nbatch_V % np == 0, "bad nbatch_V");
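    // For example (illustration only): DKQ == 576, DV == 512 with nbatch_K == 64 and nbatch_fa == 64 gives
    // nbatch_V == 64*64/512 == 8, i.e. the V data reuses the K SRAM buffer in chunks of 8 rows.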
#pragma unroll
    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
        __syncthreads();

#ifdef FAST_FP16_AVAILABLE
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            __align__(16) half2 V_k[(DVp/2)/warp_size];
            __align__(16) half2 KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                __align__(16) half tmp[KQ_cs];
                ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
                    &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
#pragma unroll
                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
                    KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
                }
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
                }
            }
        }
#else
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            __align__(16) float2 V_k[(DVp/2)/warp_size];
            __align__(16) float  KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
                }
            }
        }
#endif // FAST_FP16_AVAILABLE

        __syncthreads();
    }
}

template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // DKQ == K/Q head size, DV == V head size
__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
static __global__ void flash_attn_tile(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:

    if (
#ifdef GGML_USE_WMMA_FATTN
        (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
        (use_logit_softcap && !(DV == 128 || DV == 256))
    ) {
        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
            max_bias, m0, m1, n_head_log2, logit_softcap,
            ne00, ne01, ne02, ne03,
            nb01, nb02, nb03,
            ne10, ne11, ne12, ne13,
            nb11, nb12, nb13,
            nb21, nb22, nb23,
            ne31, ne32, ne33,
            nb31, nb32, nb33);
        NO_DEVICE_CODE;
        return;
    }

    static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");

    constexpr int ncols     = ncols1*ncols2;
    constexpr int warp_size = 32;
    constexpr int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
    constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
    constexpr int nbatch_K  = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.

    const int sequence  = blockIdx.z / (ne02/ncols2);
    const int head0     = blockIdx.z*ncols2 - sequence*ne02; // == (blockIdx.z % (ne02/ncols2)) * ncols2
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
    const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
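
    // For example (illustration only): ne02 == 32 Q heads over ne12 == 8 KV heads give gqa_ratio == 4,
    // so Q heads 0..3 all read the K/V data of KV head 0.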

    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;

    const int stride_K2   = nb11 / sizeof(half2);
    const int stride_V2   = nb21 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
    static_assert(cpw == 1 || np == 1, "bad cpw / np");
    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");

    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.

    // Q_tmp  == SRAM buffer to hold Q data for the entire lifetime of the kernel.
    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
    // KV_tmp is padded to avoid shared memory bank conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
    // KQ     == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
    // VKQ    == Accumulators in registers for the final VKQ result.
#ifdef FAST_FP16_AVAILABLE
    __shared__ half2 Q_tmp[ncols * DKQ/2];
    __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
    __shared__ half  KQ[ncols * nbatch_fa];
    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#else
    __shared__ float Q_tmp[ncols * DKQ];
    __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
    __shared__ float KQ[ncols * nbatch_fa];
    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#endif // FAST_FP16_AVAILABLE

    float KQ_max[cpw];
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
    }
    float KQ_sum[cpw] = {0.0f};

    // Load Q data, convert to FP16 if fast:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y / np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;

#pragma unroll
        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
            if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                ggml_cuda_memcpy_1<sizeof(tmp_f)>
                    (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                    + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);

#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    tmp_f[i1] *= scale;
                }

#ifdef FAST_FP16_AVAILABLE
                __align__(16) half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
                    // Therefore, scale down Q values and apply the inverse scale to the FP32 KQ values afterwards again.
                    tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
                }
                ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                    &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
                    tmp_h2);
#else
                ggml_cuda_memcpy_1<sizeof(tmp_f)>(
                    &Q_tmp[jc* DKQ    + i0   + (threadIdx.y % np)*(warp_size*cpy_ne_D)   + threadIdx.x* cpy_ne_D],
                    tmp_f);
#endif // FAST_FP16_AVAILABLE
            }
        }
    }

    __syncthreads();

    // Main loop over KV cache:
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    if (ncols2 == 1) {
        // Branch with out-of-bounds checks.
        int k_VKQ_0 = blockIdx.y*nbatch_fa;
        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
            k_VKQ_0 += gridDim.y*nbatch_fa;
        }
        if (k_VKQ_0 < k_VKQ_max) {
            constexpr bool oob_check = true;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
        }
    } else {
        // Branch without out-of-bounds checks.
        for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
        }
    }

#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
    }

    if constexpr (np > 1) {
        static_assert(cpw == 1, "bad cpw");
        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");

#ifdef FAST_FP16_AVAILABLE
        half2 * VKQ_combine = (half2 *) KV_tmp;
#else
        float * VKQ_combine = (float *) KV_tmp;
#endif // FAST_FP16_AVAILABLE
        float * KQ_sum_combine = (float *) Q_tmp;

        if (threadIdx.y % np != 0) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
            }
#endif // FAST_FP16_AVAILABLE

            if (threadIdx.x == 0) {
                KQ_sum_combine[threadIdx.y] = KQ_sum[0];
            }

            return;
        }

        __syncthreads();

#pragma unroll
        for (int ip = 1; ip < np; ++ip) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                __align__(16) half2 tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    VKQ[i0/warp_size + i1] += tmp[i1];
                }
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                __align__(16) float tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
                }
            }
#endif // FAST_FP16_AVAILABLE

            KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
        }
    }

    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
    if (sinks && blockIdx.y == 0) {
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y/np)*cpw;
            const float sink = ((const float *) sinks)[head0 + jc % ncols2];

            float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
            const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
            KQ_max[jc0] = KQ_max_new_j;

            const float val = expf(sink - KQ_max[jc0]);
            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }
    }

    // Write back results:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y/np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
            return;
        }

        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;

        const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;

#ifdef FAST_FP16_AVAILABLE
        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
            __align__(16) float2 tmp[cpy_ne_D];
#pragma unroll
            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
                tmp[i1].x *= scale;
                tmp[i1].y *= scale;
            }
            if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
            }
        }
#else
        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
            if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
                }
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
            }
        }
#endif // FAST_FP16_AVAILABLE

        if (gridDim.y != 1 && threadIdx.x == 0) {
            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
        }
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    const int id        = ggml_cuda_get_device();
    const int cc        = ggml_cuda_info().devices[id].cc;
    const int warp_size = 32;

    constexpr size_t nbytes_shared = 0;

#ifdef GGML_USE_HIP
    if constexpr (DV <= 128) {
        if (Q->ne[1] > 32/ncols2) {
            constexpr int cols_per_block = 64;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }
#endif // GGML_USE_HIP

#ifndef GGML_USE_HIP
    if constexpr (DV <= 256)
#endif // GGML_USE_HIP
    {
        if (Q->ne[1] > 16/ncols2) {
            constexpr int cols_per_block = 32;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if (Q->ne[1] > 8/ncols2) {
        constexpr int cols_per_block = 16;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] > 4/ncols2) {
            constexpr int cols_per_block = 8;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 4) {
        if (Q->ne[1] > 2/ncols2) {
            constexpr int cols_per_block = 4;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 2) {
        constexpr int cols_per_block = 2;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    GGML_ABORT("fatal error");
}

template <int DKQ, int DV, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
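
    // For example (illustration only): a model with 32 Q heads and 8 KV heads has gqa_ratio == 4; with a mask,
    // max_bias == 0, a small enough Q batch, and a KV length that is a multiple of FATTN_KQ_STRIDE this
    // enables the gqa_ratio % 4 == 0 branches below, i.e. ncols2 == 4.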

    if constexpr (DV == 512) {
        if (use_gqa_opt && gqa_ratio % 16 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
            return;
        }
        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
            return;
        }
    }

    if constexpr (DV <= 256) {
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 2 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
            return;
        }

        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
        return;
    }
    GGML_ABORT("fatal error");
}

template <int DKQ, int DV>
void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    }
}

void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
    template void ggml_cuda_flash_attn_ext_tile_case              \
    <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

extern DECL_FATTN_TILE_CASE( 40,  40);
extern DECL_FATTN_TILE_CASE( 64,  64);
extern DECL_FATTN_TILE_CASE( 72,  72);
extern DECL_FATTN_TILE_CASE( 80,  80);
extern DECL_FATTN_TILE_CASE( 96,  96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(576, 512);