// Input values. `main` addresses them as
// offset_src + i1*stride_src1 + i2*stride_src2 + i3*stride_src3 + element index,
// and only reads from this buffer; the read_write mode is kept as declared
// (presumably to match the host bind-group layout — confirm before changing).
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Output: sorted element indices (as i32) written by `main`.
@group(0) @binding(1) var<storage, read_write> dst: array<i32>;
6
// Uniform parameters consumed by `main`. Offsets and strides are element counts
// (they are used directly as indices into the f32 / i32 arrays).
// Field order and types form the uniform-buffer layout shared with the host:
// never reorder, insert or retype members.
struct Params {
    // Position of element 0 in src / dst.
    offset_src: u32, offset_dst: u32,

    // Strides of dims 1..3 for src, then for dst.
    stride_src1: u32, stride_src2: u32, stride_src3: u32,
    stride_dst1: u32, stride_dst2: u32, stride_dst3: u32,

    // Length of a src row, and the extents used to split a flat row id into (i1, i2, i3).
    src_ne0: u32, ne1: u32, ne2: u32,

    // Upper bound on the output position within a dst row, and indices kept per tile.
    ne0: u32, top_k: u32,

    // Tiles per row and number of rows; npr * nrows workgroups do useful work.
    npr: u32, nrows: u32
};
30
// Dispatch parameters; see struct Params for the meaning of each member.
@group(0) @binding(2) var<uniform> params: Params;

// Per-workgroup scratch: one src element index per invocation, sorted in place by `main`.
var<workgroup> shmem_idx: array<u32, WG_SIZE>;
35
// Sort-direction macros, chosen at shader-preprocessing time from ORDER.
//   ORDER == 0: SWAP_COMPARE_UP is `>`, so an "up" compare-exchange in main swaps
//               when the lower slot holds the larger value (smaller values first);
//               EXTREME_VALUE (1e30) is the value used for padding slots.
//   otherwise:  both comparisons are flipped (larger values first) and the padding
//               value is -1e30.
// SWAP_COMPARE_DOWN is always the mirror of SWAP_COMPARE_UP.
// NOTE(review): comments are kept off the #define lines in case the preprocessor
// copies trailing text into the macro body — confirm its behavior before changing.
#if ORDER == 0
#define EXTREME_VALUE 1e30
#define SWAP_COMPARE_UP >
#define SWAP_COMPARE_DOWN <
#else
#define EXTREME_VALUE -1e30
#define SWAP_COMPARE_UP <
#define SWAP_COMPARE_DOWN >
#endif
45
// Sorts one WG_SIZE-wide tile of one src row with an in-workgroup bitonic network
// and writes the first `top_k` sorted element indices of that tile to dst.
//
// Workgroup -> work mapping: linear = wid.x + wid.y * num_wg.x, then
// tile = linear % npr and the flat row id = linear / npr, which is split into
// (i1, i2, i3) using ne1 and ne2.
//
// Slots past the end of the row (idx >= src_ne0) hold the sentinel index src_ne0.
// They are ordered by validity rather than by a sentinel value, so they always
// land after every real element. Real data containing ±inf or magnitudes beyond
// ±1e30 therefore can no longer be displaced by padding.
//
// Assumptions (not checked here — confirm on the host side):
//   - WG_SIZE is a power of two, as the bitonic network requires;
//   - top_k <= WG_SIZE, otherwise dst slots [WG_SIZE, top_k) of a tile are never written.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let linear = wid.x + wid.y * num_wg.x;
    // Guard against overprovisioned workgroups. `linear` depends only on workgroup
    // builtins and uniform params, so this exit is uniform across the workgroup
    // and the barriers below stay in uniform control flow.
    if (linear >= params.npr * params.nrows) {
        return;
    }
    let tile = linear % params.npr;
    var row = linear / params.npr;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_base = params.offset_src +
                   i1 * params.stride_src1 +
                   i2 * params.stride_src2 +
                   i3 * params.stride_src3;

    // Each invocation seeds its slot with its element index, or the sentinel.
    let tile_base = tile * WG_SIZE;
    let idx = tile_base + lid.x;
    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
    workgroupBarrier();

    // Bitonic sort: stage k merges runs of length k; step j compare-exchanges
    // slot i with slot i ^ j. Only the lower slot of each pair acts, so the
    // pairs touched within a step are disjoint.
    var k = 2u;
    while (k <= WG_SIZE) {
        var j = k >> 1;
        while (j > 0) {
            let ixj = lid.x ^ j;
            if (ixj > lid.x) {
                let dir_up = (lid.x & k) == 0;
                let a_idx = shmem_idx[lid.x];
                let b_idx = shmem_idx[ixj];
                let a_valid = a_idx < params.src_ne0;
                let b_valid = b_idx < params.src_ne0;
                let both_valid = a_valid && b_valid;

                // Load only in-row elements. WGSL `select` evaluates both operands,
                // so a select-based load would still read src[row_base + src_ne0].
                var a_val: f32 = 0.0;
                var b_val: f32 = 0.0;
                if (a_valid) {
                    a_val = src[row_base + a_idx];
                }
                if (b_valid) {
                    b_val = src[row_base + b_idx];
                }

                // "x after y": x belongs after y in the requested order.
                // Padding is after every real element; two padding slots are tied.
                let a_after_b = (!a_valid && b_valid) ||
                                (both_valid && (a_val SWAP_COMPARE_UP b_val));
                let b_after_a = (a_valid && !b_valid) ||
                                (both_valid && (b_val SWAP_COMPARE_UP a_val));
                let should_swap = select(b_after_a, a_after_b, dir_up);
                if (should_swap) {
                    shmem_idx[lid.x] = b_idx;
                    shmem_idx[ixj] = a_idx;
                }
            }
            workgroupBarrier();
            j >>= 1;
        }
        k <<= 1;
    }

    // Emit the first top_k sorted indices of this tile, bounded by ne0.
    let out_idx = tile * params.top_k + lid.x;
    if (out_idx < params.ne0 && lid.x < params.top_k) {
        let row_dst = params.offset_dst +
                      i1 * params.stride_dst1 +
                      i2 * params.stride_dst2 +
                      i3 * params.stride_dst3;
        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
    }
}