summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
blob: 9a77f6eca740df8e9e75c1b78e813aaf6eed312f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Values that define the sort order. This shader never reorders them; it only
// reads them through the indices stored in idx_in (see take_left).
// NOTE(review): declared read_write although only read in this file —
// presumably to match the host-side bind group layout; confirm before changing.
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

// Input index buffer: for each row, consecutive runs of params.len indices
// into src that this pass merges pairwise. Only read in this file.
@group(0) @binding(1)
var<storage, read_write> idx_in: array<i32>;

// Output index buffer: receives the merged runs produced by this pass.
@group(0) @binding(2)
var<storage, read_write> idx_out: array<i32>;

// Uniform parameters for one merge pass.
// NOTE(review): field order presumably mirrors the host-side uniform buffer
// layout — confirm before reordering or inserting fields.
struct Params {
    // Base offsets added to every access of src / idx_in / idx_out.
    offset_src: u32, // in elements
    offset_in: u32,  // in elements
    offset_out: u32, // in elements

    // Strides (in elements) applied to the row coordinates i1, i2, i3 when
    // locating a row inside src.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Strides (in elements) applied to i1, i2, i3 when locating a row inside idx_in.
    stride_idx1: u32,
    stride_idx2: u32,
    stride_idx3: u32,

    // Strides (in elements) applied to i1, i2, i3 when locating a row inside idx_out.
    stride_out1: u32,
    stride_out2: u32,
    stride_out3: u32,

    // Row length: runs are clipped so they never extend past ne0 elements.
    ne0: u32,
    // Extents used to split the flat row number into (i1, i2, i3).
    ne1: u32,
    ne2: u32,

    // Upper bound on how many merged elements are written per pair of runs.
    top_k: u32,

    // Length of each input run; a merge consumes two adjacent runs (2 * len elements).
    len: u32,
    // Number of merges (run pairs) per row.
    nm: u32,
    // Number of rows; nm * nrows is the number of workgroups that do work.
    nrows: u32
};

// Per-dispatch merge parameters (see Params), bound as a uniform buffer.
@group(0) @binding(3)
var<uniform> params: Params;

// Decides which of two candidates the merge should emit next.
//
// a_idx / b_idx are indices (relative to row_base) into src for the head of
// the left and right run respectively. Returns true when the left candidate
// should be taken. Ties return true, so the left element wins on equal values.
// ORDER == 0 takes the left element when its value is <= the right one;
// any other ORDER takes it when its value is >= the right one.
fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
    let left_value = src[row_base + u32(a_idx)];
    let right_value = src[row_base + u32(b_idx)];
#if ORDER == 0
    let left_wins = left_value <= right_value;
#else
    let left_wins = left_value >= right_value;
#endif
    return left_wins;
}

// Merges one pair of adjacent index runs per workgroup.
//
// Workgroup `linear` handles merge number (linear % params.nm) of row
// (linear / params.nm). The two input runs live in idx_in and are assumed to
// be already ordered according to take_left (presumably by an earlier pass —
// confirm against the host code). The merged result is written to idx_out at
// the same position within the row. Each invocation produces one contiguous
// slice [k0, k1) of the merged output, and never more than params.top_k
// elements are written per pair of runs.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    // Flatten the 2D dispatch grid into a single workgroup index.
    let linear = wid.x + wid.y * num_wg.x;
    // guard against overprovisioned workgroups
    if (linear >= params.nm * params.nrows) {
        return;
    }

    // Left run:  [start, start + len0)
    // Right run: [start + params.len, start + params.len + len1)
    // Both are clipped to the row length params.ne0; len1 may be 0.
    // NOTE(review): `params.ne0 - start` is unsigned, so this assumes
    // start < params.ne0 — presumably guaranteed by how params.nm is chosen.
    let start = (linear % params.nm) * params.len * 2;
    let len0 = min(params.len, params.ne0 - start);
    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
    let len1 = min(params.len, rem1);
    let total = len0 + len1;

    // Split the merged output evenly across the invocations of the workgroup;
    // this invocation writes output positions [k0, k1).
    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
    let k0 = lid.x * chunk;
    let k1 = min(min(k0 + chunk, total), params.top_k);
    // guard against overprovisioned threads
    if (k0 >= params.top_k || k0 >= total) {
        return;
    }

    // Decompose the flat row number into the (i1, i2, i3) coordinates.
    var row = linear / params.nm;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_src = params.offset_src +
        i1 * params.stride_src1 +
        i2 * params.stride_src2 +
        i3 * params.stride_src3;

    let row_in = params.offset_in +
        i1 * params.stride_idx1 +
        i2 * params.stride_idx2 +
        i3 * params.stride_idx3;

    let row_out = params.offset_out +
        i1 * params.stride_out1 +
        i2 * params.stride_out2 +
        i3 * params.stride_out3;

    // Binary-search how many of the first k0 merged elements come from the
    // left run. The answer lies in [max(0, k0 - len1), min(k0, len0)], which
    // also keeps `k0 - mid - 1` inside the right run.
    var low: u32 = select(0, k0 - len1, k0 > len1);
    var high: u32 = min(k0, len0);

    while (low < high) {
        let mid = (low + high) >> 1;
        let idx0 = idx_in[row_in + start + mid];
        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
        if (take_left(idx0, idx1, row_src)) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    // Sequentially merge from the split point: i and j are the read cursors
    // into the left and right run, and i + j == k holds throughout.
    var i = low;
    var j = k0 - i;
    var k = k0;
    while (k < k1) {
        let has_left = i < len0;
        let has_right = j < len1;

        // Load each candidate once, and only from a run that still has
        // elements. `select` evaluates both of its operands, so feeding the
        // loads to it directly would read past the end of an exhausted run
        // (possibly past the row or the buffer) and fetch every index twice.
        var left_idx: i32 = 0;
        var right_idx: i32 = 0;
        if (has_left) {
            left_idx = idx_in[row_in + start + i];
        }
        if (has_right) {
            right_idx = idx_in[row_in + start + params.len + j];
        }

        // Left exhausted -> take right; right exhausted -> take left;
        // otherwise compare the two candidates.
        var take_l = has_left;
        if (has_left && has_right) {
            take_l = take_left(left_idx, right_idx, row_src);
        }

        if (take_l) {
            idx_out[row_out + start + k] = left_idx;
            i += 1u;
        } else {
            idx_out[row_out + start + k] = right_idx;
            j += 1u;
        }
        k += 1u;
    }
}