// Source values; this shader only ever reads from it.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Destination for the selected column indices, stored as i32.
@group(0) @binding(1) var<storage, read_write> dst: array<i32>;

  7struct Params {
  8    offset_src: u32, // in elements
  9    offset_dst: u32, // in elements
 10
 11    stride_src1: u32,
 12    stride_src2: u32,
 13    stride_src3: u32,
 14
 15    stride_dst1: u32,
 16    stride_dst2: u32,
 17    stride_dst3: u32,
 18
 19    // src/dst dimensions
 20    src_ne0: u32,
 21    ne1: u32,
 22    ne2: u32,
 23
 24    ne0: u32,
 25    top_k: u32,
 26
 27    npr: u32,   // tiles per row
 28    nrows: u32
 29};
 30
 31@group(0) @binding(2)
 32var<uniform> params: Params;
 33
 34var<workgroup> shmem_idx: array<u32, WG_SIZE>;
 35
 36#if ORDER == 0
 37#define EXTREME_VALUE 1e30
 38#define SWAP_COMPARE_UP >
 39#define SWAP_COMPARE_DOWN <
 40#else
 41#define EXTREME_VALUE -1e30
 42#define SWAP_COMPARE_UP <
 43#define SWAP_COMPARE_DOWN >
 44#endif
 45
 46@compute @workgroup_size(WG_SIZE)
 47fn main(@builtin(workgroup_id) wid: vec3<u32>,
 48        @builtin(num_workgroups) num_wg: vec3<u32>,
 49        @builtin(local_invocation_id) lid: vec3<u32>) {
 50    let linear = wid.x + wid.y * num_wg.x;
 51    // guard against overprovisioned workgroups
 52    if (linear >= params.npr * params.nrows) {
 53        return;
 54    }
 55    let tile = linear % params.npr;
 56    var row = linear / params.npr;
 57    let i3 = row / (params.ne2 * params.ne1);
 58    row = row % (params.ne2 * params.ne1);
 59    let i2 = row / params.ne1;
 60    let i1 = row % params.ne1;
 61
 62    let row_base = params.offset_src +
 63        i1 * params.stride_src1 +
 64        i2 * params.stride_src2 +
 65        i3 * params.stride_src3;
 66
 67    let tile_base = tile * WG_SIZE;
 68    let idx = tile_base + lid.x;
 69    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
 70    workgroupBarrier();
 71
 72    var k = 2u;
 73    while (k <= WG_SIZE) {
 74        var j = k >> 1;
 75        while (j > 0) {
 76            let ixj = lid.x ^ j;
 77            if (ixj > lid.x) {
 78                let dir_up = (lid.x & k) == 0;
 79                let a_idx = shmem_idx[lid.x];
 80                let b_idx = shmem_idx[ixj];
 81                let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
 82                let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
 83                let should_swap = select(
 84                    (a_val SWAP_COMPARE_DOWN b_val),
 85                    (a_val SWAP_COMPARE_UP b_val),
 86                    dir_up);
 87                if (should_swap) {
 88                    shmem_idx[lid.x] = b_idx;
 89                    shmem_idx[ixj] = a_idx;
 90                }
 91            }
 92            workgroupBarrier();
 93            j >>= 1;
 94        }
 95        k <<= 1;
 96    }
 97
 98    let out_idx = tile * params.top_k + lid.x;
 99    if (out_idx < params.ne0 && lid.x < params.top_k) {
100        let row_dst = params.offset_dst +
101            i1 * params.stride_dst1 +
102            i2 * params.stride_dst2 +
103            i3 * params.stride_dst3;
104        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
105    }
106}