summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp
blob: cd3f42f4911930110cab3962788de7d1376d335f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#version 450

#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint idx = get_idx();

    if (idx >= p.ne) {
        return;
    }

    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
    const uint i12_offset = i12*p.ne11*p.ne10;
    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;

    if (i10 == i11) {
        const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
    } else {
        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
    }
}