1#version 450
 2
 3#include "types.glsl"
 4#include "sum_rows.glsl"
 5
 6#extension GL_EXT_control_flow_attributes : enable
 7#extension GL_KHR_shader_subgroup_arithmetic : enable
 8#extension GL_KHR_shader_subgroup_basic : enable
 9
// 1D workgroup; its x size comes from specialization constant 0 (BLOCK_SIZE below).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Source rows (read-only) and destination rows (write-only).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// BLOCK_SIZE shares constant_id 0 with local_size_x_id above, so it equals the
// number of invocations in the workgroup.
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
// Subgroup width that main() uses for tid / SUBGROUP_SIZE and for picking the last lane.
// NOTE(review): presumably the host pins the pipeline's subgroup size to this value
// and makes BLOCK_SIZE a multiple of it; confirm.
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
// Number of consecutive row elements each invocation handles per loop iteration in main().
layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;

// Integer ceiling division (used with unsigned operands in main()).
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// One slot per subgroup, indexed by subgroup id in main(): that subgroup's total
// for the chunk currently being scanned.
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
// Running total carried from one chunk of the row to the next.
shared FLOAT_TYPE last_sum;
23
// Inclusive prefix sum (cumsum) of each row of data_a, written to data_d.
// One workgroup handles one row. The row is walked in chunks of
// BLOCK_SIZE * ELEM_PER_THREAD elements, and each invocation owns ELEM_PER_THREAD
// consecutive elements of a chunk.
//
// Within a chunk the scan is built in three levels:
//   1. a sequential inclusive scan over the invocation's own elements,
//   2. a subgroup exclusive scan over the per-invocation totals,
//   3. a workgroup pass that adds the totals of all lower subgroups (published
//      through `partial`) and the running total of earlier chunks (`last_sum`).
//
// All invocations run every iteration, even those whose columns are past the end
// of the row. This keeps the subgroup operation and both barrier() calls in
// uniform control flow.
//
// NOTE(review): the subgroup bookkeeping (tid / SUBGROUP_SIZE and the
// `gl_SubgroupInvocationID == SUBGROUP_SIZE - 1` test) assumes the runtime subgroup
// size equals SUBGROUP_SIZE and that BLOCK_SIZE is a multiple of it. Presumably the
// host guarantees this; confirm.
void main() {
    // Flatten the 3D workgroup id into a row index (262144 == 512 * 512).
    // Assumes the host keeps gl_WorkGroupID.x and .y below 512 (TODO confirm).
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Split the flat row index into (i01, i02, i03) using the precomputed fastdiv constants.
    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
    const uint i03_offset = i03 * p.ne01*p.ne02;
    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
    const uint i01 = row - i03_offset - i02*p.ne01;

    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

    const uint subgroup_id = tid / SUBGROUP_SIZE;

    // Carry from earlier chunks. It is first read only after the first barrier()
    // inside the loop, so no extra barrier is needed here.
    if (tid == 0) {
        last_sum = 0;
    }

    uint col = tid * ELEM_PER_THREAD;
    const uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
    for (uint i = 0; i < num_iter; ++i) {
        // Level 1: inclusive scan over this invocation's consecutive elements.
        // Out-of-range elements contribute nothing and are never stored.
        FLOAT_TYPE v[ELEM_PER_THREAD];
        FLOAT_TYPE thread_sum = 0;
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            if (col + j < p.n_cols) {
                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
            }
            v[j] = thread_sum;
        }

        // Level 2: total of all lower invocations in this subgroup.
        FLOAT_TYPE offset = subgroupExclusiveAdd(thread_sum);

        // Level 3: the last lane publishes its subgroup's total. Every invocation then
        // folds the totals of all lower subgroups into its single `offset`, instead of
        // adding each partial to every element separately.
        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
            partial[subgroup_id] = offset + thread_sum;
        }
        barrier();
        for (uint s = 0; s < subgroup_id; ++s) {
            offset += partial[s];
        }
        // Snapshot the carry before the barrier that allows it to be overwritten below.
        const FLOAT_TYPE carry = last_sum;
        barrier();

        // Add the within-chunk offset first and the (potentially large) carry from
        // earlier chunks last, as separate additions.
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            v[j] = (v[j] + offset) + carry;
        }
        // The last invocation holds the inclusive total of the whole chunk.
        if (tid == BLOCK_SIZE - 1) {
            last_sum = v[ELEM_PER_THREAD - 1];
        }
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            if (col + j < p.n_cols) {
                data_d[dst_idx + col + j] = D_TYPE(v[j]);
            }
        }
        col += BLOCK_SIZE * ELEM_PER_THREAD;
    }
}