1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 2
 3#ifdef cl_intel_subgroups
 4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
 5#else
 6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
 7#endif
 8
 9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
// Exchange the values of two lvalues `x` and `y` of type `T`.
// The do { } while (0) wrapper makes `SWAP(a, b, T);` a single statement, so it
// is safe as the body of an unbraced if/else. The temporary has an unusual name
// so that swapping a caller variable called `tmp` does not self-initialise.
// Each of `x` and `y` is evaluated twice: do not pass expressions with side effects.
#define SWAP(x, y, T) do { T swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
22
// Sort direction selector for the `order` argument of kernel_argsort_f32_i32.
// The numeric values matter: the kernel compares `order` against these
// constants, so they must agree with whatever value the host passes.
// NOTE(review): presumably this mirrors the host-side ggml_sort_order; confirm.
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC  = 0, // ascending: smallest value first
    GGML_SORT_ORDER_DESC = 1, // descending: largest value first
};
27
// Row-wise argsort of a float matrix into int indices, using a bitonic sort in
// local memory.
//
// One work-group handles one row (row = get_group_id(1)). Each work-item
// (col = get_local_id(0)) owns one slot of the padded index array `dst_row`.
//
//   src0     - input values. `offset0` is a byte offset applied before use.
//              Rows are ne00 floats apart.
//   dst      - output indices. `offsetd` is a byte offset applied before use.
//              Rows are ne00 ints apart.
//   ne00     - number of real elements per row.
//   ne00_pad - padded row length used by the sort. Presumably a power of two
//              >= ne00, as bitonic sort requires. NOTE(review): confirm on host.
//   order    - GGML_SORT_ORDER_ASC sorts ascending; any other value sorts
//              descending.
//   dst_row  - local scratch buffer. It must hold at least ne00_pad ints.
//
// Padding indices (>= ne00) always end up after every real index, whatever the
// order, and are never written to dst.
// NOTE(review): assumes the work-group size along dimension 0 is at least
// ne00_pad (presumably exactly ne00_pad); otherwise slots go unsorted.
kernel void kernel_argsort_f32_i32(
    global float * src0,
    ulong          offset0,
    global int   * dst,
    ulong          offsetd,
    const int      ne00,
    const int      ne00_pad,
    const int      order,
    local int    * dst_row
) {
    const int col = get_local_id(0);
    const int row = get_group_id(1);

    // Work-items beyond the padded width are masked out instead of returning.
    // Every work-item of the group must reach each barrier() below.
    const bool active = col < ne00_pad;

    // Apply the byte offsets, keeping each pointer's own element type.
    src0 = (global float *)((global char *)src0 + offset0);
    dst  = (global int   *)((global char *)dst  + offsetd);

    // 64-bit row offsets so large tensors cannot overflow `int`.
    global const float * x_row   = src0 + (ulong)row * (ulong)ne00;
    global int         * dst_out = dst  + (ulong)row * (ulong)ne00;

    // Start from the identity permutation, including the padding slots.
    if (active) {
        dst_row[col] = col;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (int k = 2; k <= ne00_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;

            // Exactly one work-item (the lower slot) handles each (col, ixj)
            // pair, so there are no races on dst_row within a step.
            // The ixj bound only guards local memory if ne00_pad is not a
            // power of two.
            if (active && ixj > col && ixj < ne00_pad) {
                // In this stage's ascending half, `col` should hold the element
                // that comes first. In the descending half the roles flip.
                const int lo = ((col & k) == 0) ? col : ixj;
                const int hi = ((col & k) == 0) ? ixj : col;

                const int a = dst_row[lo];
                const int b = dst_row[hi];

                // Swap when `a` must come after `b`. Padding (>= ne00) always
                // goes after real elements. x_row is only read for real
                // indices, thanks to short-circuit evaluation.
                const bool out_of_order =
                    a >= ne00 ||
                    (b < ne00 && (order == GGML_SORT_ORDER_ASC ?
                        x_row[a] > x_row[b] :
                        x_row[a] < x_row[b]));

                if (out_of_order) {
                    dst_row[lo] = b;
                    dst_row[hi] = a;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    // Copy the result to dst without the padding (col < ne00 implies active).
    if (col < ne00) {
        dst_out[col] = dst_row[col];
    }
}