diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh | 586 |
1 files changed, 586 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh new file mode 100644 index 0000000..3f4a78c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh | |||
| @@ -0,0 +1,586 @@ | |||
| 1 | #include "common.cuh" | ||
| 2 | #include "fattn-common.cuh" | ||
| 3 | |||
| 4 | static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) { | ||
| 5 | return 128; | ||
| 6 | GGML_UNUSED(cc); | ||
| 7 | } | ||
| 8 | |||
| 9 | static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { | ||
| 10 | return 128; | ||
| 11 | } | ||
| 12 | |||
| 13 | // Currenlty llvm with the amdgcn target does not support unrolling loops | ||
| 14 | // that contain a break that can not be resolved at compile time. | ||
| 15 | #ifdef __clang__ | ||
| 16 | #pragma clang diagnostic push | ||
| 17 | #pragma clang diagnostic ignored "-Wpass-failed" | ||
| 18 | #endif // __clang__ | ||
| 19 | template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size | ||
| 20 | __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) | ||
| 21 | static __global__ void flash_attn_ext_vec( | ||
| 22 | const char * __restrict__ Q, | ||
| 23 | const char * __restrict__ K, | ||
| 24 | const char * __restrict__ V, | ||
| 25 | const char * __restrict__ mask, | ||
| 26 | const char * __restrict__ sinks, | ||
| 27 | const int * __restrict__ KV_max, | ||
| 28 | float * __restrict__ dst, | ||
| 29 | float2 * __restrict__ dst_meta, | ||
| 30 | const float scale, | ||
| 31 | const float max_bias, | ||
| 32 | const float m0, | ||
| 33 | const float m1, | ||
| 34 | const uint32_t n_head_log2, | ||
| 35 | const float logit_softcap, | ||
| 36 | const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, | ||
| 37 | const int32_t nb01, const int32_t nb02, const int32_t nb03, | ||
| 38 | const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, | ||
| 39 | const int32_t nb11, const int32_t nb12, const int64_t nb13, | ||
| 40 | const int32_t nb21, const int32_t nb22, const int64_t nb23, | ||
| 41 | const int32_t ne31, const int32_t ne32, const int32_t ne33, | ||
| 42 | const int32_t nb31, const int32_t nb32, const int64_t nb33) { | ||
| 43 | #ifdef FLASH_ATTN_AVAILABLE | ||
| 44 | |||
| 45 | // Skip unused kernel variants for faster compilation: | ||
| 46 | if (use_logit_softcap && !(D == 128 || D == 256)) { | ||
| 47 | GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, | ||
| 48 | max_bias, m0, m1, n_head_log2, logit_softcap, | ||
| 49 | ne00, ne01, ne02, ne03, | ||
| 50 | nb01, nb02, nb03, | ||
| 51 | ne10, ne11, ne12, ne13, | ||
| 52 | nb11, nb12, nb13, | ||
| 53 | nb21, nb22, nb23, | ||
| 54 | ne31, ne32, ne33, | ||
| 55 | nb31, nb32, nb33); | ||
| 56 | NO_DEVICE_CODE; | ||
| 57 | return; | ||
| 58 | } | ||
| 59 | |||
| 60 | //In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||
| 61 | |||
| 62 | constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); | ||
| 63 | constexpr int cpy_ne = cpy_nb / 4; | ||
| 64 | |||
| 65 | #ifdef GGML_USE_HIP | ||
| 66 | #ifdef RDNA | ||
| 67 | constexpr int nthreads_KQ_q = 2; | ||
| 68 | #else | ||
| 69 | constexpr int nthreads_KQ_q = 4; | ||
| 70 | #endif // RDNA | ||
| 71 | constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); | ||
| 72 | #else | ||
| 73 | constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32); | ||
| 74 | constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); | ||
| 75 | #endif // GGML_USE_HIP | ||
| 76 | |||
| 77 | constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); | ||
| 78 | constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; | ||
| 79 | constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; | ||
| 80 | |||
| 81 | static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); | ||
| 82 | static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); | ||
| 83 | |||
| 84 | constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; | ||
| 85 | constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; | ||
| 86 | |||
| 87 | constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); | ||
| 88 | constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; | ||
| 89 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 90 | constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); | ||
| 91 | #else | ||
| 92 | constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>(); | ||
| 93 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 94 | |||
| 95 | const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. | ||
| 96 | |||
| 97 | const int sequence = blockIdx.z / ne02; | ||
| 98 | const int head = blockIdx.z - sequence*ne02; | ||
| 99 | const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. | ||
| 100 | Q += nb03*sequence + nb02* head + nb01*ic0; | ||
| 101 | K += nb13*sequence + nb12*(head / gqa_ratio); | ||
| 102 | V += nb23*sequence + nb22*(head / gqa_ratio); | ||
| 103 | |||
| 104 | const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); | ||
| 105 | |||
| 106 | const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); | ||
| 107 | |||
| 108 | static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); | ||
| 109 | constexpr int nwarps = nthreads / WARP_SIZE; | ||
| 110 | const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; | ||
| 111 | __builtin_assume(tid < nthreads); | ||
| 112 | |||
| 113 | constexpr int ne_KQ = ncols*D; | ||
| 114 | constexpr int ne_combine = nwarps*V_cols_per_iter*D; | ||
| 115 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 116 | half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; | ||
| 117 | __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; | ||
| 118 | #else | ||
| 119 | float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; | ||
| 120 | __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; | ||
| 121 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 122 | |||
| 123 | float KQ_max[ncols]; | ||
| 124 | float KQ_sum[ncols]; | ||
| 125 | #pragma unroll | ||
| 126 | for (int j = 0; j < ncols; ++j) { | ||
| 127 | KQ_max[j] = -FLT_MAX/2.0f; | ||
| 128 | KQ_sum[j] = 0.0f; | ||
| 129 | } | ||
| 130 | |||
| 131 | // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: | ||
| 132 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 133 | half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. | ||
| 134 | #else | ||
| 135 | __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. | ||
| 136 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 137 | int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; | ||
| 138 | float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; | ||
| 139 | if constexpr (Q_q8_1) { | ||
| 140 | #pragma unroll | ||
| 141 | for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||
| 142 | const int j = j0 + threadIdx.y; | ||
| 143 | |||
| 144 | if (j0 + nwarps > ncols && j >= ncols) { | ||
| 145 | break; | ||
| 146 | } | ||
| 147 | |||
| 148 | // Reuse KQ as temporary storage for converting Q to q8_1: | ||
| 149 | int * tmp_q_i32 = (int *) &KQ[j*D]; | ||
| 150 | float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); | ||
| 151 | |||
| 152 | // Set memory to zero if out of bounds: | ||
| 153 | if (ncols > 1 && ic0 + j >= int(ne01.z)) { | ||
| 154 | #pragma unroll | ||
| 155 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { | ||
| 156 | const int i = i0 + threadIdx.x; | ||
| 157 | |||
| 158 | if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { | ||
| 159 | tmp_q_i32[i] = 0; | ||
| 160 | } | ||
| 161 | } | ||
| 162 | if (threadIdx.x < D/QK8_1) { | ||
| 163 | tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); | ||
| 164 | } | ||
| 165 | } else { | ||
| 166 | const float * Q_f = (const float *) (Q + j*nb01); | ||
| 167 | constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE; | ||
| 168 | #pragma unroll | ||
| 169 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { | ||
| 170 | quantize_q8_1_to_shared<float2, nthreads_quantize> | ||
| 171 | (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); | ||
| 172 | } | ||
| 173 | } | ||
| 174 | } | ||
| 175 | |||
| 176 | __syncthreads(); | ||
| 177 | |||
| 178 | #pragma unroll | ||
| 179 | for (int j = 0; j < ncols; ++j) { | ||
| 180 | int * tmp_q_i32 = (int *) &KQ[j*D]; | ||
| 181 | float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); | ||
| 182 | |||
| 183 | #pragma unroll | ||
| 184 | for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { | ||
| 185 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ); | ||
| 186 | |||
| 187 | Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; | ||
| 188 | Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; | ||
| 189 | } | ||
| 190 | } | ||
| 191 | |||
| 192 | __syncthreads(); | ||
| 193 | } else { | ||
| 194 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 195 | const half2 scale_h2 = make_half2(scale, scale); | ||
| 196 | #pragma unroll | ||
| 197 | for (int j = 0; j < ncols; ++j) { | ||
| 198 | const float2 * Q_j = (const float2 *) (Q + j*nb01); | ||
| 199 | #pragma unroll | ||
| 200 | for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { | ||
| 201 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; | ||
| 202 | |||
| 203 | __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; | ||
| 204 | if (ncols == 1 || ic0 + j < int(ne01.z)) { | ||
| 205 | ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]); | ||
| 206 | ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); | ||
| 207 | } | ||
| 208 | #pragma unroll | ||
| 209 | for (int i1 = 0; i1 < cpy_ne; ++i1) { | ||
| 210 | Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y); | ||
| 211 | } | ||
| 212 | } | ||
| 213 | #pragma unroll | ||
| 214 | for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { | ||
| 215 | Q_reg[j][k] *= scale_h2; | ||
| 216 | } | ||
| 217 | } | ||
| 218 | #else | ||
| 219 | #pragma unroll | ||
| 220 | for (int j = 0; j < ncols; ++j) { | ||
| 221 | const float2 * Q_j = (const float2 *) (Q + j*nb01); | ||
| 222 | #pragma unroll | ||
| 223 | for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { | ||
| 224 | const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; | ||
| 225 | if (ncols == 1 || ic0 + j < int(ne01.z)) { | ||
| 226 | ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); | ||
| 227 | ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); | ||
| 228 | } | ||
| 229 | } | ||
| 230 | #pragma unroll | ||
| 231 | for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { | ||
| 232 | Q_reg[j][k].x *= scale; | ||
| 233 | Q_reg[j][k].y *= scale; | ||
| 234 | } | ||
| 235 | } | ||
| 236 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 237 | } | ||
| 238 | |||
| 239 | const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; | ||
| 240 | K += blockIdx.y*nthreads * nb11; | ||
| 241 | V += blockIdx.y*nthreads * nb21; | ||
| 242 | maskh += blockIdx.y*nthreads; | ||
| 243 | for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads, | ||
| 244 | // Increment pointers after each loop: | ||
| 245 | K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) { | ||
| 246 | |||
| 247 | // Calculate KQ tile and keep track of new maximum KQ values: | ||
| 248 | float KQ_reg[ncols]; // KQ in registers. | ||
| 249 | |||
| 250 | float KQ_max_new[ncols]; | ||
| 251 | #pragma unroll | ||
| 252 | for (int j = 0; j < ncols; ++j) { | ||
| 253 | KQ_max_new[j] = KQ_max[j]; | ||
| 254 | } | ||
| 255 | |||
| 256 | #pragma unroll | ||
| 257 | for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { | ||
| 258 | const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0; | ||
| 259 | |||
| 260 | #pragma unroll | ||
| 261 | for (int j = 0; j < ncols; ++j) { | ||
| 262 | float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); | ||
| 263 | sum = warp_reduce_sum<nthreads_KQ>(sum); | ||
| 264 | |||
| 265 | if (use_logit_softcap) { | ||
| 266 | sum = logit_softcap*tanhf(sum); | ||
| 267 | } | ||
| 268 | |||
| 269 | if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { | ||
| 270 | sum += slope*__half2float(maskh[j*ne11 + i_KQ]); | ||
| 271 | } | ||
| 272 | |||
| 273 | KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET); | ||
| 274 | |||
| 275 | if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { | ||
| 276 | KQ_reg[j] = sum; | ||
| 277 | } | ||
| 278 | } | ||
| 279 | } | ||
| 280 | |||
| 281 | #pragma unroll | ||
| 282 | for (int j = 0; j < ncols; ++j) { | ||
| 283 | #pragma unroll | ||
| 284 | for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { | ||
| 285 | KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); | ||
| 286 | } | ||
| 287 | const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); | ||
| 288 | KQ_max[j] = KQ_max_new[j]; | ||
| 289 | |||
| 290 | KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); | ||
| 291 | KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; | ||
| 292 | KQ[j*nthreads + tid] = KQ_reg[j]; | ||
| 293 | |||
| 294 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 295 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); | ||
| 296 | #pragma unroll | ||
| 297 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 298 | VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; | ||
| 299 | } | ||
| 300 | #else | ||
| 301 | #pragma unroll | ||
| 302 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 303 | VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; | ||
| 304 | VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; | ||
| 305 | } | ||
| 306 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 307 | } | ||
| 308 | |||
| 309 | #ifndef GGML_USE_HIP | ||
| 310 | __syncwarp(); | ||
| 311 | #endif // GGML_USE_HIP | ||
| 312 | |||
| 313 | #pragma unroll | ||
| 314 | for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { | ||
| 315 | const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); | ||
| 316 | |||
| 317 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 318 | half2 KQ_k[ncols]; | ||
| 319 | #pragma unroll | ||
| 320 | for (int j = 0; j < ncols; ++j) { | ||
| 321 | KQ_k[j] = __half2half2(KQ[j*nthreads + k]); | ||
| 322 | } | ||
| 323 | #pragma unroll | ||
| 324 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { | ||
| 325 | half2 tmp[V_rows_per_thread/2]; | ||
| 326 | dequantize_V(V + k*nb21, tmp, | ||
| 327 | 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); | ||
| 328 | #pragma unroll | ||
| 329 | for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { | ||
| 330 | #pragma unroll | ||
| 331 | for (int j = 0; j < ncols; ++j) { | ||
| 332 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; | ||
| 333 | } | ||
| 334 | } | ||
| 335 | } | ||
| 336 | #else | ||
| 337 | float KQ_k[ncols]; | ||
| 338 | #pragma unroll | ||
| 339 | for (int j = 0; j < ncols; ++j) { | ||
| 340 | KQ_k[j] = KQ[j*nthreads + k]; | ||
| 341 | } | ||
| 342 | #pragma unroll | ||
| 343 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { | ||
| 344 | float2 tmp[V_rows_per_thread/2]; | ||
| 345 | dequantize_V(V + k*nb21, tmp, | ||
| 346 | 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); | ||
| 347 | #pragma unroll | ||
| 348 | for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { | ||
| 349 | #pragma unroll | ||
| 350 | for (int j = 0; j < ncols; ++j) { | ||
| 351 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; | ||
| 352 | VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; | ||
| 353 | } | ||
| 354 | } | ||
| 355 | } | ||
| 356 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 357 | } | ||
| 358 | } | ||
| 359 | |||
| 360 | if (sinks && blockIdx.y == 0) { | ||
| 361 | const float sink = ((const float *) sinks)[head]; | ||
| 362 | |||
| 363 | #pragma unroll | ||
| 364 | for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||
| 365 | const int j = j0 + threadIdx.y; | ||
| 366 | |||
| 367 | if (j0 + nwarps > ncols && j >= ncols) { | ||
| 368 | break; | ||
| 369 | } | ||
| 370 | |||
| 371 | const float kqmax_new_j = fmaxf(sink, KQ_max[j]); | ||
| 372 | const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j); | ||
| 373 | KQ_max[j] = kqmax_new_j; | ||
| 374 | |||
| 375 | KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); | ||
| 376 | |||
| 377 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 378 | const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); | ||
| 379 | #pragma unroll | ||
| 380 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 381 | VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; | ||
| 382 | } | ||
| 383 | #else | ||
| 384 | #pragma unroll | ||
| 385 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 386 | VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; | ||
| 387 | VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; | ||
| 388 | } | ||
| 389 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 390 | } | ||
| 391 | } | ||
| 392 | |||
| 393 | __shared__ float KQ_max_shared[ncols][WARP_SIZE]; | ||
| 394 | __shared__ float KQ_sum_shared[ncols][WARP_SIZE]; | ||
| 395 | #pragma unroll | ||
| 396 | for (int j = 0; j < ncols; ++j) { | ||
| 397 | if (threadIdx.y == 0) { | ||
| 398 | KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f; | ||
| 399 | KQ_sum_shared[j][threadIdx.x] = 0.0f; | ||
| 400 | } | ||
| 401 | } | ||
| 402 | |||
| 403 | __syncthreads(); | ||
| 404 | |||
| 405 | #pragma unroll | ||
| 406 | for (int j = 0; j < ncols; ++j) { | ||
| 407 | if (threadIdx.x == 0) { | ||
| 408 | KQ_max_shared[j][threadIdx.y] = KQ_max[j]; | ||
| 409 | } | ||
| 410 | } | ||
| 411 | __syncthreads(); | ||
| 412 | |||
| 413 | #pragma unroll | ||
| 414 | for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { | ||
| 415 | if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { | ||
| 416 | break; | ||
| 417 | } | ||
| 418 | |||
| 419 | float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x]; | ||
| 420 | kqmax_new = warp_reduce_max(kqmax_new); | ||
| 421 | const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); | ||
| 422 | KQ_max[j_VKQ] = kqmax_new; | ||
| 423 | |||
| 424 | #ifdef V_DOT2_F32_F16_AVAILABLE | ||
| 425 | half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) | ||
| 426 | + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); | ||
| 427 | |||
| 428 | const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale); | ||
| 429 | #pragma unroll | ||
| 430 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 431 | VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; | ||
| 432 | } | ||
| 433 | #pragma unroll | ||
| 434 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { | ||
| 435 | const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); | ||
| 436 | |||
| 437 | ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); | ||
| 438 | } | ||
| 439 | #else | ||
| 440 | float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) | ||
| 441 | + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); | ||
| 442 | |||
| 443 | #pragma unroll | ||
| 444 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { | ||
| 445 | VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale; | ||
| 446 | VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale; | ||
| 447 | } | ||
| 448 | #pragma unroll | ||
| 449 | for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { | ||
| 450 | const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); | ||
| 451 | |||
| 452 | ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); | ||
| 453 | ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); | ||
| 454 | } | ||
| 455 | #endif // V_DOT2_F32_F16_AVAILABLE | ||
| 456 | |||
| 457 | KQ_sum[j_VKQ] *= kqmax_scale; | ||
| 458 | KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); | ||
| 459 | if (threadIdx.x == 0) { | ||
| 460 | KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ]; | ||
| 461 | } | ||
| 462 | |||
| 463 | __syncthreads(); | ||
| 464 | |||
| 465 | if (nthreads <= D || tid < D) { | ||
| 466 | KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x]; | ||
| 467 | KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); | ||
| 468 | |||
| 469 | #pragma unroll | ||
| 470 | for (int i0 = 0; i0 < D; i0 += nthreads) { | ||
| 471 | float dst_val = 0; | ||
| 472 | #pragma unroll | ||
| 473 | for (int w = 0; w < nwarps; ++w) { | ||
| 474 | #pragma unroll | ||
| 475 | for (int v = 0; v < V_cols_per_iter; ++v) { | ||
| 476 | dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); | ||
| 477 | } | ||
| 478 | } | ||
| 479 | if (gridDim.y == 1) { | ||
| 480 | dst_val /= KQ_sum[j_VKQ]; | ||
| 481 | } | ||
| 482 | dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; | ||
| 483 | } | ||
| 484 | } | ||
| 485 | |||
| 486 | if (j_VKQ < ncols-1) { | ||
| 487 | __syncthreads(); | ||
| 488 | } | ||
| 489 | |||
| 490 | } | ||
| 491 | |||
| 492 | if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { | ||
| 493 | dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); | ||
| 494 | } | ||
| 495 | #else | ||
| 496 | GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, | ||
| 497 | max_bias, m0, m1, n_head_log2, logit_softcap, | ||
| 498 | ne00, ne01, ne02, ne03, | ||
| 499 | nb01, nb02, nb03, | ||
| 500 | ne10, ne11, ne12, ne13, | ||
| 501 | nb11, nb12, nb13, | ||
| 502 | nb21, nb22, nb23, | ||
| 503 | ne31, ne32, ne33, | ||
| 504 | nb31, nb32, nb33); | ||
| 505 | NO_DEVICE_CODE; | ||
| 506 | #endif // FLASH_ATTN_AVAILABLE | ||
| 507 | } | ||
| 508 | #ifdef __clang__ | ||
| 509 | #pragma clang diagnostic pop | ||
| 510 | #endif // __clang__ | ||
| 511 | |||
| 512 | template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> | ||
| 513 | void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| 514 | const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; | ||
| 515 | |||
| 516 | const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); | ||
| 517 | const int nwarps = nthreads / WARP_SIZE; | ||
| 518 | fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>; | ||
| 519 | const bool need_f16_K = type_K == GGML_TYPE_F16; | ||
| 520 | const bool need_f16_V = type_V == GGML_TYPE_F16; | ||
| 521 | constexpr size_t nbytes_shared = 0; | ||
| 522 | launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); | ||
| 523 | } | ||
| 524 | |||
| 525 | template <int D, ggml_type type_K, ggml_type type_V> | ||
| 526 | void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| 527 | const ggml_tensor * KQV = dst; | ||
| 528 | const ggml_tensor * Q = dst->src[0]; | ||
| 529 | |||
| 530 | float logit_softcap; | ||
| 531 | memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); | ||
| 532 | |||
| 533 | if (Q->ne[1] == 1) { | ||
| 534 | constexpr int cols_per_block = 1; | ||
| 535 | if (logit_softcap == 0.0f) { | ||
| 536 | constexpr bool use_logit_softcap = false; | ||
| 537 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); | ||
| 538 | } else { | ||
| 539 | constexpr bool use_logit_softcap = true; | ||
| 540 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); | ||
| 541 | } | ||
| 542 | return; | ||
| 543 | } | ||
| 544 | |||
| 545 | constexpr int cols_per_block = 2; | ||
| 546 | if (logit_softcap == 0.0f) { | ||
| 547 | constexpr bool use_logit_softcap = false; | ||
| 548 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); | ||
| 549 | } else { | ||
| 550 | constexpr bool use_logit_softcap = true; | ||
| 551 | ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); | ||
| 552 | } | ||
| 553 | } | ||
| 554 | |||
| 555 | #define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ | ||
| 556 | template void ggml_cuda_flash_attn_ext_vec_case \ | ||
| 557 | <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ | ||
| 558 | |||
| 559 | #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ | ||
| 560 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ | ||
| 561 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ | ||
| 562 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ | ||
| 563 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ | ||
| 564 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ | ||
| 565 | extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ | ||
| 566 | |||
| 567 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) | ||
| 568 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) | ||
| 569 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) | ||
| 570 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) | ||
| 571 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) | ||
| 572 | EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) | ||
| 573 | |||
| 574 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) | ||
| 575 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) | ||
| 576 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) | ||
| 577 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) | ||
| 578 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) | ||
| 579 | EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) | ||
| 580 | |||
| 581 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) | ||
| 582 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) | ||
| 583 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) | ||
| 584 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) | ||
| 585 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) | ||
| 586 | EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) | ||
