1#version 450
2
3#include "rte.glsl"
4#include "types.glsl"
5#include "generic_unary_head.glsl"
6
7#define GGML_TRI_TYPE_UPPER_DIAG 0
8#define GGML_TRI_TYPE_UPPER 1
9#define GGML_TRI_TYPE_LOWER_DIAG 2
10#define GGML_TRI_TYPE_LOWER 3
11
12layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
13
14void main() {
15 const uint idx = get_idx();
16
17 if (idx >= p.ne) {
18 return;
19 }
20
21 const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
22 const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
23 const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
24 const uint i02_offset = i02*p.ne01*p.ne00;
25 const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
26 const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
27
28 int param = floatBitsToInt(p.param1);
29 bool pass = false;
30 switch (param) {
31 case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
32 case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
33 case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
34 case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
35 }
36
37 if (pass) {
38 const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
39 data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
40 } else {
41 data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
42 }
43}