// Optional extensions are only requested when the compiler advertises them
// through the matching feature macro: an `enable` pragma naming an
// unsupported extension may be diagnosed as an error.
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// Prefer the Intel subgroup extension; otherwise use the Khronos one if present.
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#elif defined(cl_khr_subgroups)
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Vendor-specific kernel attributes for pinning the subgroup size.
// INTEL_GPU / ADRENO_GPU and the REQD_SUBGROUP_SIZE_* macros are defined only
// when the corresponding vendor extension is available.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
20
// Exchange the values of two lvalues x and y of type T.
// Wrapped in do { } while (0) so that `SWAP(a, b, T);` is a single statement
// and stays valid as the body of an unbraced if/else. The temporary has a
// deliberately unusual name so it cannot shadow an argument such as `tmp`.
// Note: x and y are each evaluated twice; avoid arguments with side effects.
#define SWAP(x, y, T) do { T swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
22
// Sort direction selector; the argsort kernel compares its `order` argument
// against these values.
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC  = 0, // smallest value first
    GGML_SORT_ORDER_DESC = 1  // largest value first
};
27
28kernel void kernel_argsort_f32_i32(
29 global float * src0,
30 ulong offset0,
31 global int * dst,
32 ulong offsetd,
33 const int ne00,
34 const int ne00_pad,
35 const int order,
36 local int * dst_row
37) {
38 // bitonic sort
39 int col = get_local_id(0);
40 int row = get_group_id(1);
41
42 if (col >= ne00_pad) {
43 return;
44 }
45
46 src0 = (global char *)((global char *)src0 + offset0);
47 dst = (global float *)((global char *)dst + offsetd);
48
49 global float * x_row = src0 + row * ne00;
50
51 // initialize indices
52 dst_row[col] = col;
53
54 barrier(CLK_LOCAL_MEM_FENCE);
55
56 for (int k = 2; k <= ne00_pad; k *= 2) {
57 for (int j = k / 2; j > 0; j /= 2) {
58 int ixj = col ^ j;
59 if (ixj > col) {
60 if ((col & k) == 0) {
61 if (dst_row[col] >= ne00 ||
62 (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ?
63 x_row[dst_row[col]] > x_row[dst_row[ixj]] :
64 x_row[dst_row[col]] < x_row[dst_row[ixj]]))
65 ) {
66 SWAP(dst_row[col], dst_row[ixj], int);
67 }
68 } else {
69 if (dst_row[ixj] >= ne00 ||
70 (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ?
71 x_row[dst_row[col]] < x_row[dst_row[ixj]] :
72 x_row[dst_row[col]] > x_row[dst_row[ixj]]))
73 ) {
74 SWAP(dst_row[col], dst_row[ixj], int);
75 }
76 }
77 }
78 barrier(CLK_LOCAL_MEM_FENCE);
79 }
80 }
81
82 // copy the result to dst without the padding
83 if (col < ne00) {
84 dst[row * ne00 + col] = dst_row[col];
85 }
86}