summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
blob: 06efd7d9fb43644ca3bd81dc058bd13ec420de3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#version 450

#include "soft_max_large_common.glsl"

// Workgroup-shared scratch for the tree reduction of the partial sums in main():
// one slot per invocation (BLOCK_SIZE slots), reduced into sumsh[0].
// The matching max reduction in main() uses `vals`, which is not declared in this
// file (presumably in soft_max_large_common.glsl — NOTE(review): confirm).
shared FLOAT_TYPE sumsh[BLOCK_SIZE];

// Final pass of the multi-workgroup softmax: normalizes data_d by the row sum.
//
// Dispatch layout (from the indexing below): gl_WorkGroupID.y selects the row,
// gl_WorkGroupID.x selects a chunk of BLOCK_SIZE * num_iters columns of that row.
// Every workgroup of a row first rebuilds the row-wide max and sum from the
// per-workgroup partials in data_m / data_s (one entry per workgroup of the
// dispatch, indexed rowx * gl_NumWorkGroups.x + w). It then multiplies its own
// chunk of data_d by 1/sum.
//
// NOTE(review): the partial sums in data_s are added without any exp() rescaling.
// This assumes they were already computed relative to the same row-wide max
// (presumably by the previous pass) — confirm against soft_max_large2.comp.
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint rowx = gl_WorkGroupID.y;
    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

    // Decompose the flat row index; only i02 is needed here, to index data_c
    // (read only when p.has_sinks != 0).
    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;

    // rowx comes from gl_WorkGroupID.y, so this branch is uniform across the
    // workgroup and returning before the barrier() calls below is safe.
    if (rowx >= p.nrows_x) {
        return;
    }

    // 0xFF800000 is the IEEE-754 bit pattern of -infinity (identity for max).
    // With sinks, the sink value also takes part in the row max.
    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

    // Each invocation folds the partials of workgroups tid, tid + BLOCK_SIZE, ...
    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
        if (i + tid < gl_NumWorkGroups.x) {
            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
        }
    }

    // Tree reduction across the workgroup into vals[0] / sumsh[0]
    // (assumes BLOCK_SIZE is a power of two).
    vals[tid] = max_val;
    sumsh[tid] = sum;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            // Combine with the slot's running value, not this invocation's private
            // max_val. Otherwise values folded into vals[tid] by earlier
            // iterations would be dropped.
            vals[tid] = max(vals[tid], vals[tid + s]);
            sumsh[tid] += sumsh[tid + s];
        }
        barrier();
    }

    max_val = vals[0];
    sum = sumsh[0];

    // The sink contributes one extra term to the denominator.
    if (p.has_sinks != 0) {
        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
    }

    FLOAT_TYPE rcpdivisor = 1.0/sum;

    // Scale this workgroup's chunk of the row; columns past p.KX are skipped.
    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
        const uint col = col0 + tid;

        if (col >= p.KX) {
            continue;
        }

        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
    }
}