1#version 450
 2
 3#include "soft_max_large_common.glsl"
 4
 5shared FLOAT_TYPE sumsh[BLOCK_SIZE];
 6
 7void main() {
 8    const uint tid = gl_LocalInvocationID.x;
 9    const uint rowx = gl_WorkGroupID.y;
10    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
11
12    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
13    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
14    const uint32_t i01 = rowx % p.ne01;
15
16    uint rowy_start = 0;
17    if (p.KY > 0) {
18        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
19    }
20
21    if (rowx >= p.nrows_x) {
22        return;
23    }
24
25    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
26    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
27
28    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
29        if (i + tid < gl_NumWorkGroups.x) {
30            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
31            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
32        }
33    }
34
35    // reduce across the workgroup
36    vals[tid] = max_val;
37    sumsh[tid] = sum;
38    barrier();
39    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
40        if (tid < s) {
41            vals[tid] = max(max_val, vals[tid + s]);
42            sumsh[tid] += sumsh[tid + s];
43        }
44        barrier();
45    }
46
47    max_val = vals[0];
48    sum = sumsh[0];
49
50    if (p.has_sinks != 0) {
51        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
52    }
53
54    FLOAT_TYPE rcpdivisor = 1.0/sum;
55
56    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
57        const uint col = col0 + tid;
58
59        if (col >= p.KX) {
60            continue;
61        }
62
63        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
64    }
65}