1#version 450
 2
 3#include "soft_max_large_common.glsl"
 4
 5void main() {
 6    const uint tid = gl_LocalInvocationID.x;
 7    const uint rowx = gl_WorkGroupID.y;
 8    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
 9
10    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
11    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
12    const uint32_t i01 = rowx % p.ne01;
13
14    uint rowy_start = 0;
15    if (p.KY > 0) {
16        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
17    }
18
19    if (rowx >= p.nrows_x) {
20        return;
21    }
22
23    float slope = get_slope(rowx);
24
25    // Find max
26    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
27
28    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
29        const uint col = col0 + tid;
30
31        FLOAT_TYPE a = FLOAT_TYPE(0);
32        if (col < p.KX) {
33            a = data_a[rowx * p.KX + col];
34        }
35
36        FLOAT_TYPE b = FLOAT_TYPE(0);
37        if (p.KY > 0 && col < p.KX) {
38            b = data_b[rowy_start + col];
39        }
40
41        FLOAT_TYPE v = a * p.scale + slope * b;
42
43        if (col < p.KX) {
44            max_val = max(max_val, v);
45        }
46    }
47
48    // reduce across the workgroup
49    vals[tid] = max_val;
50    barrier();
51    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
52        if (tid < s) {
53            vals[tid] = max(vals[tid], vals[tid + s]);
54        }
55        barrier();
56    }
57
58    if (tid == 0) {
59        max_val = vals[0];
60        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
61    }
62}