summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cuda/argsort.cuh
blob: 22b7306f20201d81009327e8ff9f6eb0a029007e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include "common.cuh"

// Declaration of the CUDA backend's argsort op handler (definition lives outside
// this header).
//   ctx - CUDA backend context, passed by mutable reference.
//   dst - destination tensor, passed as a mutable pointer.
// NOTE(review): presumably the input tensor and the sort order are obtained
// from dst (its source tensors / op parameters) — confirm against the definition.
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

// Argsort launcher that is only declared when GGML_CUDA_USE_CUB is defined;
// callers must guard their use of it with the same macro.
// NOTE(review): the name suggests a CUB-based implementation that sorts f32
// values and produces i32 indices — confirm against the definition.
//   pool   - CUDA memory pool (presumably used for temporary buffers; verify).
//   x      - read-only float input.
//   dst    - int output buffer.
//   ncols  - column count of the input.
//   nrows  - row count of the input.
//            (presumably each row of ncols elements is sorted independently —
//            TODO confirm)
//   order  - requested sort order.
//   stream - CUDA stream (presumably the work is enqueued on it; verify).
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                              const float *    x,
                              int *            dst,
                              const int        ncols,
                              const int        nrows,
                              ggml_sort_order  order,
                              cudaStream_t     stream);
#endif  // GGML_CUDA_USE_CUB
// Argsort launcher that is declared unconditionally (it is not behind any
// preprocessor guard in this header) and takes no memory pool argument.
// NOTE(review): the name suggests a bitonic-sort implementation that sorts f32
// values and produces i32 indices; any limit on ncols is not visible from this
// declaration — confirm against the definition.
//   x      - read-only float input.
//   dst    - int output buffer.
//   ncols  - column count of the input.
//   nrows  - row count of the input.
//            (presumably each row of ncols elements is sorted independently —
//            TODO confirm)
//   order  - requested sort order.
//   stream - CUDA stream (presumably the work is enqueued on it; verify).
void argsort_f32_i32_cuda_bitonic(const float *   x,
                                  int *           dst,
                                  const int       ncols,
                                  const int       nrows,
                                  ggml_sort_order order,
                                  cudaStream_t    stream);