1#version 450
2
3#include "soft_max_large_common.glsl"
4
// Per-workgroup softmax pass for one row (gl_WorkGroupID.y) and one column
// tile (gl_WorkGroupID.x) of BLOCK_SIZE * num_iters columns.
//
// Steps:
//   1. Combine the gl_NumWorkGroups.x entries of data_m belonging to this row
//      (and data_c[i02] when p.has_sinks != 0) into a single row maximum.
//      NOTE(review): data_m presumably holds per-workgroup partial maxima
//      written by an earlier pass - confirm against the dispatching code.
//   2. For every column of this tile, write exp(a*scale + slope*b - max) to
//      data_d and accumulate the values.
//   3. Reduce the accumulated values across the workgroup and store the
//      partial sum in data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x].
//
// The values written to data_d are not divided by the sum in this shader.
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint rowx = gl_WorkGroupID.y;
    // First column handled by this workgroup.
    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

    // Decompose the flat row index into (i03, i02, i01) using ne01 / ne02.
    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
    const uint32_t i01 = rowx % p.ne01;

    // Offset of the matching row in data_b; the modulo wraps i02 / i03 into
    // the ne12 / ne13 extents. Only computed when p.KY > 0.
    uint rowy_start = 0;
    if (p.KY > 0) {
        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
    }

    // rowx comes from gl_WorkGroupID, so this exit is taken by the whole
    // workgroup or by none of it; no invocation is left waiting at a barrier.
    if (rowx >= p.nrows_x) {
        return;
    }

    float slope = get_slope(rowx);

    // Find max. 0xFF800000 is the bit pattern of -infinity, the identity for max().
    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];

    // Each invocation scans a strided subset of this row's data_m entries.
    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
        if (i + tid < gl_NumWorkGroups.x) {
            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
        }
    }

    // reduce across the workgroup
    // NOTE(review): the halving scheme assumes BLOCK_SIZE is a power of two - confirm.
    vals[tid] = max_val;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            // Combine with the running value stored in vals[tid], not with the
            // invocation's initial max_val; otherwise the maxima merged in by
            // earlier steps are overwritten and the result can be too small.
            vals[tid] = max(vals[tid], vals[tid + s]);
        }
        barrier();
    }

    max_val = vals[0];
    // Every invocation must have read vals[0] before vals is reused below.
    barrier();

    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

    // Compute sum{exp(x - max)}
    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
        const uint col = col0 + tid;

        // Columns only increase from here, so nothing is left for this invocation.
        if (col >= p.KX) {
            break;
        }

        // compute exp(a*scale + b*slope - max), store it and add it to sum
        const uint i = rowx * p.KX + col;
        FLOAT_TYPE val;
        val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
        sum += val;
        data_d[i] = D_TYPE(val);
    }

    // reduce across the workgroup
    vals[tid] = sum;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            vals[tid] += vals[tid + s];
        }
        barrier();
    }

    // Invocation 0 publishes this workgroup's partial sum for the row.
    if (tid == 0) {
        sum = vals[0];
        data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
    }
}