#include "argsort.cuh"

#ifdef GGML_CUDA_USE_CUB
#    include <cub/cub.cuh>
#    if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1)
#        define STRIDED_ITERATOR_AVAILABLE
#    endif
using namespace cub;
#endif  // GGML_CUDA_USE_CUB

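// Fill each row of `indices` with the identity permutation 0..ncols-1; these
// are the values that get reordered alongside the float keys during the sort.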
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;

    if (col < ncols && row < nrows) {
        indices[row * ncols + col] = col;
    }
}

#ifndef STRIDED_ITERATOR_AVAILABLE
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx <= nrows) {
        offsets[idx] = idx * ncols;
    }
}
#endif  // STRIDED_ITERATOR_AVAILABLE

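// CUB path: DeviceRadixSort for a single row, DeviceSegmentedSort for many
// rows. The keys are sorted in a scratch copy so the input stays untouched.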
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                              const float *    x,
                              int *            dst,
                              const int        ncols,
                              const int        nrows,
                              ggml_sort_order  order,
                              cudaStream_t     stream) {
    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);

    int *   temp_indices = temp_indices_alloc.get();
    float * temp_keys    = temp_keys_alloc.get();

    static const int block_size = 256;
    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);

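    // Segment offsets: with CCCL >= 3.1 a strided counting iterator describes
    // them implicitly; otherwise they are materialized on the device.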
#ifdef STRIDED_ITERATOR_AVAILABLE
    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
#else
    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
    int *                     offset_iterator = offsets_alloc.get();
    // nrows + 1 offsets are written, so the grid must cover nrows + 1 threads
    const dim3                offset_grid((nrows + 1 + block_size - 1) / block_size);
    init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
#endif
    CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));

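    // CUB's two-phase convention: a first call with a null workspace pointer
    // only computes temp_storage_bytes; the second call performs the sort.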
    size_t temp_storage_bytes = 0;

    if (order == GGML_SORT_ORDER_ASC) {
        if (nrows == 1) {
            DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                       temp_indices, dst,                                  // values (indices)
                                       ncols, 0, sizeof(float) * 8, stream);
        } else {
            DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                           temp_indices, dst,                                  // values (indices)
                                           ncols * nrows, nrows,  // num items, num segments
                                           offset_iterator, offset_iterator + 1, stream);
        }
    } else {
        if (nrows == 1) {
            DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                                 temp_indices, dst,                                  // values (indices)
                                                 ncols, 0, sizeof(float) * 8, stream);
        } else {
            DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
                                                     stream);
        }
    }

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void *                        d_temp_storage = temp_storage_alloc.get();

    if (order == GGML_SORT_ORDER_ASC) {
        if (nrows == 1) {
            DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                       temp_indices, dst,  // values (indices)
                                       ncols, 0, sizeof(float) * 8, stream);
        } else {
            DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
        }
    } else {
        if (nrows == 1) {
            DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                                 temp_indices, dst,  // values (indices)
                                                 ncols, 0, sizeof(float) * 8, stream);
        } else {
            DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,
                                                     offset_iterator + 1, stream);
        }
    }
}
#endif  // GGML_CUDA_USE_CUB

// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
    a = b;
    b = tmp;
}

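// One block per row, one thread per padded column; the permutation is built
// in shared memory. Padding indices (>= ncols) lose every comparison, so they
// collect at the tail and are dropped when the result is written out.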
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.x;

    if (col >= ncols_pad) {
        return;
    }

    const float * x_row = x + row * ncols;
    extern __shared__ int dst_row[];

    // initialize indices
    dst_row[col] = col;

    __syncthreads();

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}

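// smallest power of two >= x (returns 1 for x <= 1)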
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

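// Bitonic path: launches one block of ncols_pad threads per row and keeps the
// whole permutation in shared memory, so a row must fit within the per-block
// shared memory limit (smpb).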
void argsort_f32_i32_cuda_bitonic(const float *   x,
                                  int *           dst,
                                  const int       ncols,
                                  const int       nrows,
                                  ggml_sort_order order,
                                  cudaStream_t    stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);

    const dim3 block_dims(ncols_pad, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    const size_t shared_mem = ncols_pad * sizeof(int);

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ABORT("fatal error");
    }
}

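// Entry point: uses the bitonic kernel when a row fits in shared memory and
// in a single thread block (ncols <= 1024); otherwise falls back to CUB.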
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    int * dst_d = (int *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

#ifdef GGML_CUDA_USE_CUB
    const int    ncols_pad      = next_power_of_2(ncols);
    const size_t shared_mem     = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

    if (shared_mem > max_shared_mem || ncols > 1024) {
        ggml_cuda_pool & pool = ctx.pool();
        argsort_f32_i32_cuda_cub(pool, src0_d, dst_d, ncols, nrows, order, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, dst_d, ncols, nrows, order, stream);
    }
#else
    argsort_f32_i32_cuda_bitonic(src0_d, dst_d, ncols, nrows, order, stream);
#endif
}