diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/mmvq.cu | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/mmvq.cu')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cuda/mmvq.cu | 767 |
1 files changed, 767 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvq.cu b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu new file mode 100644 index 0000000..ce25ccf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu @@ -0,0 +1,767 @@ +#include "mmvq.cuh" +#include "quantize.cuh" +#include "unary.cuh" +#include "vecdotq.cuh" + +#include <cstdint> + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } +} + +static constexpr __device__ int get_vdr_mmvq(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } +} + +enum mmvq_parameter_table_id { + MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_GCN, + MMVQ_PARAMETERS_RDNA2 +}; + +static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) + return MMVQ_PARAMETERS_RDNA2; +#elif defined(GCN) || defined(CDNA) + return MMVQ_PARAMETERS_GCN; +#else + return MMVQ_PARAMETERS_GENERIC; +#endif +} + +static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA2; + } + if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { + return MMVQ_PARAMETERS_GCN; + } + return MMVQ_PARAMETERS_GENERIC; +} + +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } else if (table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 2; + case 5: + case 6: + case 7: + case 8: + default: + return 1; + } + } + return 1; +} + +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + return 1; + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } + return 1; +} + +template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false> +__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int qi = ggml_cuda_type_traits<type>::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr mmvq_parameter_table_id table_id = get_device_table_id(); + constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const int tid = warp_size*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + uint32_t token_idx = 0; + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + } + + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; + + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + const void * vgate = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + ggml_glu_op active_glu; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + vgate = fusion.gate; + x_bias = (const float *) fusion.x_bias; + gate_bias = (const float *) fusion.gate_bias; + active_glu = fusion.glu_op; + } + + + float x_biases[ncols_dst] = { 0.0f }; + float gate_biases[ncols_dst] = { 0.0f }; + if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; + if (use_bias) { + x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + // 1. Hide latency by prefetching bias and gate here + // 2. load only on threads that won't die after partial sum calculation + if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && + (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x]; + } + } + } + if (use_gate_bias) { + gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && + (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x]; + } + } + } + } + + // partial sum for each thread + float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + + const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y; + } + const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += vec_dot_q_cuda( + vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + if constexpr (!has_fusion) { + (void) tmp_shared_gate; + } else if (!use_gate) { + (void) tmp_shared_gate; + } + + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i]; + } + } + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + + if constexpr (is_multi_token_id) { + dst += token_idx*stride_col_dst; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x]; + } + } + } + tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]); + } + } + } + + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { + float result = tmp[j][threadIdx.x]; + if constexpr (has_fusion) { + if (use_bias) { + result += x_biases[j]; + } + if (use_gate) { + float gate_value = tmp_gate[j][threadIdx.x]; + if (use_gate_bias) { + gate_value += gate_biases[j]; + } + switch (active_glu) { + case GGML_GLU_OP_SWIGLU: + result *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + result *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + result = ggml_cuda_op_swiglu_oai_single(gate_value, result); + break; + } + default: + result = result * gate_value; + break; + } + } + } + dst[j*stride_col_dst + threadIdx.x] = result; + } + } + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + } +} + +static std::pair<dim3, dim3> calc_launch_params( + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, + const int warp_size, const mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); + const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); + return {block_nums, block_dims}; +} + +template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false> +static void mul_mat_vec_q_switch_fusion( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (c_ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); +} + +template <ggml_type type> +static void mul_mat_vec_q_switch_ncols_dst( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride, cudaStream_t stream) { + + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); + GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); + + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } + + switch (ncols_dst) { + case 1: { + constexpr int c_ncols_dst = 1; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 2: { + constexpr int c_ncols_dst = 2; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 3: { + constexpr int c_ncols_dst = 3; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 4: { + constexpr int c_ncols_dst = 4; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 5: { + constexpr int c_ncols_dst = 5; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 6: { + constexpr int c_ncols_dst = 6; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 7: { + constexpr int c_ncols_dst = 7; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 8: { + constexpr int c_ncols_dst = 8; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + default: + GGML_ABORT("fatal error"); + break; + } + + GGML_UNUSED(has_fusion); +} +static void mul_mat_vec_q_switch_type( + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride, cudaStream_t stream) { + switch (type_x) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_MXFP4: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + + // If src0 is a temporary compute buffer, clear any potential padding. + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1); + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + } + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = ne10_padded / QK8_1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const int64_t s12 = ne11*s11; + const int64_t s13 = ne12*s12; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + + mul_mat_vec_q_switch_type( + src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, + ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, ids_stride, stream); +} + +void ggml_cuda_op_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + const int stride_row_x = ne00 / ggml_blck_size(src0->type); + const int stride_col_y = src1_padded_row_size / QK8_1; + + ggml_cuda_mm_fusion_args_device fusion_local{}; + mul_mat_vec_q_switch_type( + src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); + + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); +} |
