Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/mmq.cu')
-rw-r--r--    llama.cpp/ggml/src/ggml-cuda/mmq.cu    366
1 file changed, 366 insertions(+), 0 deletions(-)
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmq.cu b/llama.cpp/ggml/src/ggml-cuda/mmq.cu
new file mode 100644
index 0000000..9a69f41
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cu
@@ -0,0 +1,366 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"

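// Dispatches to the mul_mat_q_case template instantiation matching the quantization type of src0.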
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    switch (args.type_x) {
        case GGML_TYPE_Q4_0:
            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
            break;
        case GGML_TYPE_MXFP4:
            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

void ggml_cuda_mul_mat_q(
        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.

    GGML_TENSOR_BINARY_OP_LOCALS;

    cudaStream_t stream = ctx.stream();
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT( nb00 == ts_src0);
    GGML_ASSERT( nb10 == ts_src1);
    GGML_ASSERT( nb0  == ts_dst);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));

    const char  * src0_d = (const char  *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       * dst_d  = (float       *) dst->data;

    // If src0 is a temporary compute buffer, clear any potential padding.
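    // The MMQ kernels may read past ggml_nbytes(src0), up to the allocated size,
    // so the tail needs to hold zeros rather than stale pool data.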
    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
        const size_t size_data  = ggml_nbytes(src0);
        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
        if (size_alloc > size_data) {
            GGML_ASSERT(ggml_is_contiguously_allocated(src0));
            GGML_ASSERT(!src0->view_src);
            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
        }
    }

    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s1  = dst->nb[1]  / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s2  = dst->nb[2]  / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s3  = dst->nb[3]  / ts_dst;

    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
        || GGML_CUDA_CC_IS_CDNA(cc);

    // TODO: tighter pool buffer size vs q8 path
    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;

    if (!ids) {
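        // Pool buffer for src1 quantized to q8_1: every padded row, plus one extra tile of
        // get_mmq_x_max_host(cc) block_q8_1_mmq structs as headroom at the end. The native MXFP4
        // path reuses the same buffer (see the static_assert on the block sizes below).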
        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);

        {
            const int64_t s11 = src1->nb[1] / ts_src1;
            const int64_t s12 = src1->nb[2] / ts_src1;
            const int64_t s13 = src1->nb[3] / ts_src1;
            if (use_native_mxfp4) {
                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                                        ne11, ne12, ne13, stream);
            } else {
                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                                       ne11, ne12, ne13, stream);
            }
            CUDA_CHECK(cudaGetLastError());
        }

        // Stride depends on quantization format
        const int64_t s12 = use_native_mxfp4 ?
            ne11 * ne10_padded * sizeof(block_fp4_mmq) /
                (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
            :
            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
        const int64_t s13 = ne12*s12;
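        // Worked example for the q8_1 path, with hypothetical shapes, assuming QK8_1 == 32 and
        // sizeof(block_q8_1) == 36 (a half2 scale/sum pair plus 32 int8 values): a padded row of
        // ne10_padded = 512 values occupies 16 blocks = 576 bytes = 144 ints, so with ne11 = 8
        // rows per channel, s12 = 8*512*36/(32*4) = 1152 ints.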

        const mmq_args args = {
            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
            ne00, ne01, ne1, s01, ne11, s1,
            ne02, ne12, s02, s12, s2,
            ne03, ne13, s03, s13, s3,
            use_stream_k, ne1};
        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
        return;
    }

    GGML_ASSERT(ne13 == 1);
    GGML_ASSERT(nb12 % nb11 == 0);
    GGML_ASSERT(nb2  % nb1  == 0);

    const int64_t n_expert_used = ids->ne[0];
    const int64_t ne_get_rows = ne12 * n_expert_used;
    GGML_ASSERT(ne1 == n_expert_used);

    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);

    {
        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
        const int si1 = ids->nb[1] / ggml_element_size(ids);
        const int sis1 = nb12 / nb11;

        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
        CUDA_CHECK(cudaGetLastError());
    }
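    // ids_src1/ids_dst map the expert-sorted token slots back to src1/dst rows; expert_bounds
    // holds the prefix offsets delimiting each expert's slot range (hence ne02 + 1 entries).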

    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);

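    // The gathered rows are quantized as a single flat batch of ne12*n_expert_used rows.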
    const int64_t ne11_flat = ne12*n_expert_used;
    const int64_t ne12_flat = 1;
    const int64_t ne13_flat = 1;

    {
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[3] / ts_src1;

        if (use_native_mxfp4) {
            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        } else {
            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        }
        CUDA_CHECK(cudaGetLastError());
    }

    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
                                           ne11 * ne10_padded * sizeof(block_q8_1)    / (QK8_1       * sizeof(int));
    const int64_t s13 = ne12*s12;

    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
    const mmq_args args = {
        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
        ne02, ne02, s02, s12, s2,
        ne03, ne13, s03, s13, s3,
        use_stream_k, ne12};

    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}

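// Variant used by the split-buffer (multi-GPU) code path: it operates on a slice of src0 rows
// [row_low, row_high) and on src1 data that has already been quantized to q8_1 (src1_ddq_i).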
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    const int64_t row_diff = row_high - row_low;
    const int64_t stride01 = ne00 / ggml_blck_size(src0->type);

    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;

    // The main device has a larger memory buffer to hold the results from all GPUs.
    // nrows_dst == nrows of the matrix that the kernel writes into.
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
    // Also, its fixup needs to allocate a temporary buffer in the memory pool, and there are
    // multiple parallel CUDA streams when src1_ncols != ne11, which would introduce a race condition for this buffer.
    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
        || GGML_CUDA_CC_IS_CDNA(cc))
        && src1_ncols == ne11;
    const mmq_args args = {
        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
        1, 1, 0, 0, 0,
        1, 1, 0, 0, 0,
        use_stream_k, src1_ncols};

    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);

    GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
}

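// Decides between the MMQ kernels and dequantization + cuBLAS/hipBLAS based on the quantization
// type, the compute capability, the batch size ne11, and (for MoE models) the number of experts.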
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
#ifdef GGML_CUDA_FORCE_CUBLAS
    return false;
#endif // GGML_CUDA_FORCE_CUBLAS

    bool mmq_supported;

    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            mmq_supported = true;
            break;
        default:
            mmq_supported = false;
            break;
    }

    if (!mmq_supported) {
        return false;
    }

    if (turing_mma_available(cc)) {
        return true;
    }

    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
        return false;
    }

#ifdef GGML_CUDA_FORCE_MMQ
    return true;
#endif // GGML_CUDA_FORCE_MMQ

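    // Reaching this point on NVIDIA means the GPU lacks (or was compiled without) int8 tensor core
    // support, so MMQ runs on dp4a. dp4a only beats FP16 cuBLAS when FP16 tensor cores are also
    // unavailable or the batch is small.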
    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
    }

    if (amd_mfma_available(cc)) {
        // As of ROCm 7.0, rocblas/tensile performs very poorly on CDNA3; hipblaslt (via ROCBLAS_USE_HIPBLASLT)
        // performs better but is currently suffering from a crash on this architecture.
        // TODO: revisit when hipblaslt is fixed on CDNA3.
        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
            return true;
        }
        if (n_experts > 64 || ne11 <= 128) {
            return true;
        }
        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
            return true;
        }
        if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
            return true;
        }
        return false;
    }

    if (amd_wmma_available(cc)) {
        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
            // High expert counts are almost always better on MMQ due to
            // the synchronization overhead in the cuBLAS/hipBLAS path:
            // https://github.com/ggml-org/llama.cpp/pull/18202
            if (n_experts >= 64) {
                return true;
            }

            // For some quantization types MMQ can have lower peak TOPS than hipBLAS,
            // so it's only faster for sufficiently small batch sizes:
            switch (type) {
                case GGML_TYPE_Q2_K:
                    return ne11 <= 128;
                case GGML_TYPE_Q6_K:
                    return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
                case GGML_TYPE_IQ2_XS:
                case GGML_TYPE_IQ2_S:
                    return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
                default:
                    return true;
            }
        }

        // For RDNA4, MMQ is consistently faster than dequantization + hipBLAS:
        // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
        return true;
    }

    return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
