1#version 450
 2
 3#extension GL_EXT_shader_16bit_storage : require
 4#if ADD_RMS
 5#extension GL_KHR_shader_subgroup_arithmetic : enable
 6#extension GL_KHR_shader_subgroup_basic : enable
 7#endif
 8
 9#include "types.glsl"
10#include "generic_binary_head.glsl"
11
// Invocations per workgroup. main() walks its elements in steps of this size,
// and num_threads * num_iter (num_iter is declared inside main) must stay equal
// to 512 to agree with wg_denoms and the get_idx() calculation.
const uint num_threads = 256u;

// Destination of the optional sum-of-squares reduction: main() stores one
// value per workgroup here when the ADD_RMS path is compiled in and
// p.param3 != 0.
layout(binding = 3, std430) buffer PartialBuf {
    float partial_sums[];
};

// One-dimensional workgroup of num_threads invocations.
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// Shared scratch for the cross-subgroup reduction in main(), indexed by
// gl_SubgroupID.
// TODO: this could be sized by the number of subgroups, but that count is not
// a compile-time constant, so the array is sized for the worst case of one
// slot per invocation.
shared FLOAT_TYPE sumsh[num_threads];
#endif
22
// Element-wise add: data_d = data_a + data_b. In ADD_RMS builds, when
// p.param3 != 0 at runtime, the workgroup also reduces the sum of squares of
// the results and stores it in partial_sums.
//
// Each invocation handles up to num_iter elements, starting at get_idx() and
// stepping by num_threads; elements with idx >= p.ne are skipped. Inputs are
// added in FLOAT_TYPE and the result is stored as D_TYPE. The indexing and
// offset helpers (get_idx, get_indices, src0_idx, src1_idx, dst_idx,
// get_aoffset/get_boffset/get_doffset) and the push constants `p` come from
// generic_binary_head.glsl.
//
// NOTE(review): p.param3 presumably tells the shader whether the host wants
// the partial sums (e.g. for a fused RMS norm) — confirm against the dispatch
// code.
void main() {
    uint idx = get_idx();
    // The starting index is kept to select this workgroup's partial_sums slot.
    uint orig_idx = idx;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    FLOAT_TYPE sum_sq = 0;

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        // idx only grows, so once it is past the end every remaining
        // iteration is skipped as well.
        if (idx >= p.ne) {
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03);

        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
        sum_sq += sum*sum;

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

        idx += num_threads;
    }

#if ADD_RMS
    if (p.param3 != 0) {
        // Two-level reduction: subgroupAdd within each subgroup, then a tree
        // over the per-subgroup totals held in shared memory.
        //
        // gl_NumSubgroups is used instead of num_threads / gl_SubgroupSize.
        // It is uniform across the whole workgroup even when the subgroup
        // size is allowed to vary, so every invocation reaches each barrier()
        // below the same number of times, and no subgroup is dropped when
        // subgroups are not fully populated.
        const uint num_subgroups = gl_NumSubgroups;

        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();

        // Start at half of the next power of two >= num_subgroups so that a
        // non-power-of-two count is folded correctly; the
        // `gl_SubgroupID + s < num_subgroups` guard skips slots that have no
        // partner. For a power-of-two count this is exactly the plain
        // halving tree (same summation order).
        uint pow2 = 1;
        while (pow2 < num_subgroups) {
            pow2 <<= 1;
        }
        for (uint s = pow2 >> 1; s > 0; s >>= 1) {
            // Writers touch slots [0, s) and readers touch slots >= s, so a
            // single barrier per level is sufficient.
            if (gl_SubgroupInvocationID == 0 && gl_SubgroupID < s && gl_SubgroupID + s < num_subgroups) {
                sum_sq += sumsh[gl_SubgroupID + s];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
        }

        // After the tree, the first invocation of subgroup 0 holds the
        // workgroup total in sum_sq.
        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
        }
    }
#endif
}