// Values being ordered. Only read by this shader (see take_left).
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Indices into `src` produced by the previous pass; read as two adjacent runs per merge.
@group(0) @binding(1) var<storage, read_write> idx_in: array<i32>;

// Destination for the merged indices written by `main`.
@group(0) @binding(2) var<storage, read_write> idx_out: array<i32>;
  9
 10struct Params {
 11    offset_src: u32, // in elements
 12    offset_in: u32,  // in elements
 13    offset_out: u32, // in elements
 14
 15    stride_src1: u32,
 16    stride_src2: u32,
 17    stride_src3: u32,
 18
 19    stride_idx1: u32,
 20    stride_idx2: u32,
 21    stride_idx3: u32,
 22
 23    stride_out1: u32,
 24    stride_out2: u32,
 25    stride_out3: u32,
 26
 27    ne0: u32,
 28    ne1: u32,
 29    ne2: u32,
 30
 31    top_k: u32,
 32
 33    len: u32,
 34    nm: u32,
 35    nrows: u32
 36};
 37
 38@group(0) @binding(3)
 39var<uniform> params: Params;
 40
 41fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
 42    let a_val = src[row_base + u32(a_idx)];
 43    let b_val = src[row_base + u32(b_idx)];
 44#if ORDER == 0
 45    return a_val <= b_val;
 46#else
 47    return a_val >= b_val;
 48#endif
 49}
 50
 51@compute @workgroup_size(WG_SIZE)
 52fn main(@builtin(workgroup_id) wid: vec3<u32>,
 53        @builtin(num_workgroups) num_wg: vec3<u32>,
 54        @builtin(local_invocation_id) lid: vec3<u32>) {
 55    let linear = wid.x + wid.y * num_wg.x;
 56    // guard against overprovisioned workgroups
 57    if (linear >= params.nm * params.nrows) {
 58        return;
 59    }
 60
 61    let start = (linear % params.nm) * params.len * 2;
 62    let len0 = min(params.len, params.ne0 - start);
 63    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
 64    let len1 = min(params.len, rem1);
 65    let total = len0 + len1;
 66    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
 67    let k0 = lid.x * chunk;
 68    let k1 = min(min(k0 + chunk, total), params.top_k);
 69    // guard against overprovisioned threads
 70    if (k0 >= params.top_k || k0 >= total) {
 71        return;
 72    }
 73
 74    var row = linear / params.nm;
 75    let i3 = row / (params.ne2 * params.ne1);
 76    row = row % (params.ne2 * params.ne1);
 77    let i2 = row / params.ne1;
 78    let i1 = row % params.ne1;
 79
 80    let row_src = params.offset_src +
 81        i1 * params.stride_src1 +
 82        i2 * params.stride_src2 +
 83        i3 * params.stride_src3;
 84
 85    let row_in = params.offset_in +
 86        i1 * params.stride_idx1 +
 87        i2 * params.stride_idx2 +
 88        i3 * params.stride_idx3;
 89
 90    let row_out = params.offset_out +
 91        i1 * params.stride_out1 +
 92        i2 * params.stride_out2 +
 93        i3 * params.stride_out3;
 94
 95
 96    var low: u32 = select(0, k0 - len1, k0 > len1);
 97    var high: u32 = min(k0, len0);
 98
 99    while (low < high) {
100        let mid = (low + high) >> 1;
101        let idx0 = idx_in[row_in + start + mid];
102        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
103        if (take_left(idx0, idx1, row_src)) {
104            low = mid + 1;
105        } else {
106            high = mid;
107        }
108    }
109
110    var i = low;
111    var j = k0 - i;
112    var k = k0;
113    while (k < k1) {
114        var take_l = false;
115        if (i >= len0) {
116            take_l = false;
117        } else if (j >= len1) {
118            take_l = true;
119        } else {
120            let idx0 = idx_in[row_in + start + i];
121            let idx1 = idx_in[row_in + start + params.len + j];
122            take_l = take_left(idx0, idx1, row_src);
123        }
124
125        let out_idx = select(
126            idx_in[row_in + start + params.len + j],
127            idx_in[row_in + start + i],
128            take_l);
129        idx_out[row_out + start + k] = out_idx;
130        i = select(i, i + 1, take_l);
131        j = select(j + 1, j, take_l);
132        k += 1;
133    }
134}