aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/tri.cl
blob: 35cdd543bc5c730aded859f2183647d64b1cb316 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

//------------------------------------------------------------------------------
// tri
//------------------------------------------------------------------------------
/**
 * Keep one triangular part of each ne0 x ne1 slice of src0 and zero the rest.
 *
 * One work-item handles one element. Its flat index idx is decomposed as
 *   i0 = idx % ne0           (fastest-varying index, dim 0)
 *   i1 = (idx / ne0) % ne1   (dim 1)
 * The same flat index is used to read src0 and to write dst, so both buffers
 * are assumed to be contiguous with identical layout
 * (NOTE(review): confirm the host side enforces this).
 *
 * @param src0     input buffer (only read)
 * @param offset0  byte offset of the first input element within src0
 * @param dst      output buffer
 * @param offsetd  byte offset of the first output element within dst
 * @param n        total number of elements to process
 * @param ne0      extent of dim 0
 * @param ne1      extent of dim 1
 * @param tri_type element (i0, i1) is copied when:
 *                   0 -> i0 >= i1, 1 -> i0 > i1, 2 -> i0 <= i1,
 *                   any other value -> i0 < i1;
 *                 otherwise 0.0f is written.
 *                 (presumably mirrors ggml_tri_type UPPER_DIAG/UPPER/LOWER_DIAG/LOWER — TODO confirm)
 */
__kernel void kernel_tri_f32(
        global const float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int n,
        int ne0,
        int ne1,
        int tri_type
) {
    src0 = (global const float *)((global const char *)src0 + offset0);
    dst  = (global float *)((global char *)dst + offsetd);

    // Bounds-check in the native size_t width before narrowing to int. A direct
    // int conversion of an id >= 2^31 (e.g. from a padded global size) could
    // wrap negative, slip past `idx >= n` and access memory out of bounds.
    // A non-positive n means there is nothing to do (same as the original).
    const size_t gid = get_global_id(0);
    if (n <= 0 || gid >= (size_t)n) {
        return;
    }
    const int idx = (int)gid;

    const int i0 = idx % ne0;
    const int i1 = (idx / ne0) % ne1;

    bool keep;
    switch (tri_type) {
        case 0:  keep = (i0 >= i1); break;
        case 1:  keep = (i0 >  i1); break;
        case 2:  keep = (i0 <= i1); break;
        default: keep = (i0 <  i1); break; // any other value behaves like 3
    }

    dst[idx] = keep ? src0[idx] : 0.0f;
}