| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/sumrows.cu | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/sumrows.cu')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cuda/sumrows.cu | 43 |
1 file changed, 43 insertions(+), 0 deletions(-)
```diff
diff --git a/llama.cpp/ggml/src/ggml-cuda/sumrows.cu b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu
new file mode 100644
index 0000000..4025771
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu
@@ -0,0 +1,43 @@
+#include "reduce_rows.cuh"
+#include "sumrows.cuh"
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const dim3 block_nums(nrows, 1, 1);
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    }
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_nums(nrows, 1, 1);
+
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        // Increase num threads to 512 for small nrows to better hide the latency
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
+}
```
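For context, the `reduce_rows_f32` template launched above is defined in `reduce_rows.cuh`, which this diff does not include. The following is a minimal, self-contained sketch of the same block-per-row reduction pattern, assuming a standard warp-shuffle reduction; `row_sum_f32` is a hypothetical name for illustration, not the ggml API.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical sketch: one block per row, mirroring the launch geometry in the
// diff (grid = nrows blocks). Each thread strides over the row's columns, then
// the block combines partial sums with warp shuffles and shared memory.
__global__ void row_sum_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += x[row * ncols + col];
    }

    // Reduce within each warp.
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    // One partial sum per warp, reduced by the first warp.
    __shared__ float warp_sums[32];
    const int lane = threadIdx.x % warpSize;
    const int warp = threadIdx.x / warpSize;
    if (lane == 0) {
        warp_sums[warp] = sum;
    }
    __syncthreads();

    if (warp == 0) {
        const int nwarps = (blockDim.x + warpSize - 1) / warpSize;
        sum = lane < nwarps ? warp_sums[lane] : 0.0f;
        for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        }
        if (lane == 0) {
            dst[row] = sum;
        }
    }
}

int main() {
    const int nrows = 4;
    const int ncols = 1000;
    float *x, *dst;
    cudaMallocManaged(&x, nrows * ncols * sizeof(float));
    cudaMallocManaged(&dst, nrows * sizeof(float));
    for (int i = 0; i < nrows * ncols; i++) {
        x[i] = 1.0f;
    }

    // Mirror the diff's small-nrows branch: too few blocks to occupy all SMs,
    // so use wide 512-thread blocks to expose parallelism within each row.
    row_sum_f32<<<nrows, 512>>>(x, dst, ncols);
    cudaDeviceSynchronize();

    for (int row = 0; row < nrows; row++) {
        printf("row %d sum = %f\n", row, dst[row]); // expect 1000.0 per row
    }
    cudaFree(x);
    cudaFree(dst);
    return 0;
}
```

The dispatch heuristic in the committed code follows the same logic as the inline comments state: when `nrows / nsm < 2` there are too few blocks to keep every SM busy, so each block gets 512 threads to hide memory latency within a row; when rows are plentiful, smaller blocks (32 threads for rows under 1024 elements, 128 otherwise) let the scheduler pack more blocks per SM.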
