#version 450

#include "soft_max_large_common.glsl"

// Per-invocation partial sums for the workgroup tree reduction in main().
// (The matching max reduction uses `vals`, which is not declared in this file;
// presumably it is a shared array provided by soft_max_large_common.glsl.)
shared FLOAT_TYPE sumsh[BLOCK_SIZE];

// Final normalization pass of a softmax whose rows are split across several
// workgroups along x.
//
// Dispatch layout, as used by the indexing below:
//   gl_WorkGroupID.y selects the row `rowx` (rows of length p.KX in data_d);
//   gl_WorkGroupID.x selects a slab of BLOCK_SIZE * num_iters columns of it.
//
// Every workgroup of a row re-reduces the per-workgroup partials stored in
// data_m (partial max) and data_s (partial sum), one entry per x-workgroup at
// index rowx * gl_NumWorkGroups.x + j. It then adds the optional sink term
// exp(data_c[i02] - max) to the sum and scales its own slab of data_d by 1 / sum.
//
// NOTE(review): the data_s partials are added without any rescaling. That
// assumes they were all computed against the same row-wide max, presumably
// by an earlier pass. The reduced max is only consumed by the sink term here.
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint rowx = gl_WorkGroupID.y;
    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

    // Decompose the flat row index. Only i02 is needed: it selects the sink
    // value data_c[i02].
    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;

    // rowx comes from gl_WorkGroupID.y, so it is uniform across the workgroup.
    // The whole workgroup therefore exits together, and the barrier() calls
    // below stay in uniform control flow.
    if (rowx >= p.nrows_x) {
        return;
    }

    // Start from -infinity (0xFF800000 is the IEEE-754 single-precision bit
    // pattern for -inf), or from the sink value when sinks are enabled.
    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

    // Each invocation folds every BLOCK_SIZE-th per-workgroup partial,
    // starting at its own tid.
    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
        if (i + tid < gl_NumWorkGroups.x) {
            const uint part = rowx * gl_NumWorkGroups.x + i + tid;
            max_val = max(max_val, data_m[part]);
            sum += data_s[part];
        }
    }

    // Reduce the max and the sum across the workgroup (tree reduction).
    vals[tid] = max_val;
    sumsh[tid] = sum;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            // Combine with the running shared value vals[tid], not the private
            // max_val. vals[tid] already holds what earlier steps folded in;
            // max(max_val, ...) would overwrite and drop those contributions.
            vals[tid] = max(vals[tid], vals[tid + s]);
            sumsh[tid] += sumsh[tid + s];
        }
        barrier();
    }
    max_val = vals[0];
    sum = sumsh[0];

    if (p.has_sinks != 0) {
        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
    }

    const FLOAT_TYPE rcpdivisor = 1.0 / sum;

    // Scale this workgroup's slab of the row in place; columns past p.KX are skipped.
    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
        const uint col = col0 + tid;
        if (col >= p.KX) {
            continue;
        }
        data_d[rowx * p.KX + col] *= D_TYPE(rcpdivisor);
    }
}