#include "argsort.cuh"
#include "top-k.cuh"

#ifdef GGML_CUDA_USE_CUB
#    include <cub/cub.cuh>
// DeviceTopK requires CCCL >= 3.2. The check must accept any later major too:
// a plain `MAJOR >= 3 && MINOR >= 2` would wrongly reject e.g. CCCL 4.0/4.1.
#    if (CCCL_MAJOR_VERSION > 3) || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2)
#        define CUB_TOP_K_AVAILABLE
using namespace cub;
#    endif  // CCCL >= 3.2
#endif      // GGML_CUDA_USE_CUB
11
12#ifdef CUB_TOP_K_AVAILABLE
13
// Writes the indices of the k largest elements of a single row `src` of
// length `ncols` into `dst` (as I32). Determinism and output ordering are
// explicitly relaxed for speed, so the returned indices are NOT sorted.
// Temporary storage is taken from (and returned to) the ggml CUDA pool;
// all work is enqueued on `stream`.
static void top_k_cub(ggml_cuda_pool & pool,
                      const float *    src,
                      int *            dst,
                      const int        ncols,
                      const int        k,
                      cudaStream_t     stream) {
    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                                 cuda::execution::output_ordering::unsorted);
    auto stream_env   = cuda::stream_ref{ stream };
    auto env          = cuda::std::execution::env{ stream_env, requirements };

    // The "values" paired with each float are simply the element positions 0..ncols-1.
    auto indexes_in = cuda::make_counting_iterator(0);

    // First call (d_temp_storage == nullptr) only queries the required
    // temporary storage size. CUB device-wide algorithms return cudaError_t;
    // check it like every other CUDA call in this file instead of dropping it.
    size_t temp_storage_bytes = 0;
    CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
                                    ncols, k, env));

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void *                        d_temp_storage = temp_storage_alloc.get();

    // Second call performs the actual top-k selection; the top-k keys
    // themselves are discarded, only their indices are kept.
    CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in,
                                    dst, ncols, k, env));
}
37
38#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
39
// Smallest power of two that is >= x (returns 1 for any x <= 1).
static int next_power_of_2(int x) {
    int p = 1;
    for (; p < x; p <<= 1) {
    }
    return p;
}
47
48#endif                            // CUB_TOP_K_AVAILABLE
49
// Top-k operator: for each row of src0 (F32), writes the I32 indices of the
// k largest elements to dst, where k = dst->ne[0]. With the CUB DeviceTopK
// path the indices within a row are unordered (top_k_cub relaxes output
// ordering); the argsort fallbacks produce them in descending value order.
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    const float *       src0_d = (const float *) src0->data;
    int *               dst_d  = (int *) dst->data;
    cudaStream_t        stream = ctx.stream();

    // Preconditions: F32 input, I32 output, contiguous input rows
    // (the raw-pointer row arithmetic below relies on contiguity).
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t    ncols = src0->ne[0];
    const int64_t    nrows = ggml_nrows(src0);
    const int64_t    k     = dst->ne[0];
    ggml_cuda_pool & pool  = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
    // https://github.com/NVIDIA/cccl/issues/6391
    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
    // NOTE(review): ncols/k are narrowed from int64_t to int at this call —
    // assumes a row never exceeds INT_MAX elements; confirm against upstream
    // tensor-size limits.
    for (int i = 0; i < nrows; i++) {
        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
    }
#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
    // Fall back to argsort + copy: argsort every row in descending order into
    // a temporary buffer, then copy the first k indices of each row to dst.
    const int    ncols_pad      = next_power_of_2(ncols);
    const size_t shared_mem     = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
    int *                     tmp_dst = temp_dst_alloc.get();

    // The bitonic argsort keeps a power-of-2-padded row of indices in shared
    // memory, so it is only usable when that fits under the per-block shared
    // memory limit (and the row is small); otherwise use the CUB radix argsort.
    if (shared_mem > max_shared_mem || ncols > 1024) {
        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    }
    // Strided device-to-device copy: dst rows are k ints wide, tmp rows are
    // ncols ints wide; take the first k*sizeof(int) bytes of each of nrows rows.
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#else                             // GGML_CUDA_USE_CUB
    // No CUB at all: the bitonic argsort is the only available implementation.
    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
    int *                     tmp_dst = temp_dst_alloc.get();
    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    // Same strided copy as above: first k indices of each sorted row -> dst.
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#endif
}