1#version 450
 2
 3#include "rte.glsl"
 4#include "types.glsl"
 5#include "generic_unary_head.glsl"
 6
 7#define GGML_TRI_TYPE_UPPER_DIAG 0
 8#define GGML_TRI_TYPE_UPPER      1
 9#define GGML_TRI_TYPE_LOWER_DIAG 2
10#define GGML_TRI_TYPE_LOWER      3
11
12layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
13
14void main() {
15    const uint idx = get_idx();
16
17    if (idx >= p.ne) {
18        return;
19    }
20
21    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
22    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
23    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
24    const uint i02_offset = i02*p.ne01*p.ne00;
25    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
26    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
27
28    int param = floatBitsToInt(p.param1);
29    bool pass = false;
30    switch (param) {
31    case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
32    case GGML_TRI_TYPE_UPPER:      pass = i00 >  i01; break;
33    case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
34    case GGML_TRI_TYPE_LOWER:      pass = i00 <  i01; break;
35    }
36
37    if (pass) {
38        const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
39        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
40    } else {
41        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
42    }
43}