1#version 450
2
3#include "types.glsl"
4#include "sum_rows.glsl"
5
6#extension GL_EXT_control_flow_attributes : enable
7#extension GL_KHR_shader_subgroup_arithmetic : enable
8#extension GL_KHR_shader_subgroup_basic : enable
9
// Workgroup width along x comes from specialization constant 0 (BLOCK_SIZE).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Specialization constants; the initializers are the defaults used when the
// host does not override them.
layout(constant_id = 0) const uint BLOCK_SIZE = 128;    // invocations per workgroup (also local_size_x)
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;  // subgroup width the kernel is written for
layout(constant_id = 2) const uint ELEM_PER_THREAD = 4; // consecutive elements each invocation handles per chunk

// Input (read-only) and output (write-only) element arrays.
layout(binding = 0) readonly buffer A {
    A_TYPE data_a[];
};
layout(binding = 1) writeonly buffer D {
    D_TYPE data_d[];
};

// Integer division rounded up.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Workgroup-shared scratch: one slot per subgroup for that subgroup's chunk
// total, plus the running total carried from one chunk to the next.
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
23
// Inclusive prefix sum (cumulative sum) along each row of data_a, written to data_d.
//
// One workgroup processes one row. The row is walked in chunks of
// BLOCK_SIZE * ELEM_PER_THREAD elements; inside a chunk, invocation `tid`
// owns the ELEM_PER_THREAD consecutive columns starting at
// tid * ELEM_PER_THREAD. The scan of a chunk is assembled in three levels:
//   1. a serial running sum over the invocation's own elements (v[]),
//   2. subgroupExclusiveAdd of the per-invocation totals within a subgroup,
//   3. the per-subgroup chunk totals published in `partial`, which every
//      invocation adds for all lower-numbered subgroups.
// The row total through the end of the previous chunk is carried in
// `last_sum` and added last.
//
// NOTE(review): correctness assumes the hardware subgroup size equals the
// SUBGROUP_SIZE specialization constant (both the last-lane test and
// subgroup_id use it) and that BLOCK_SIZE is a multiple of it. Confirm the
// host pins the subgroup size for this pipeline.
void main() {
    // Flattened row index. The 512 / 262144 (= 512 * 512) strides presumably
    // mirror how the host spreads the row count over the dispatch dimensions.
    // TODO: confirm against the dispatch code.
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Split the flat row index into (i01, i02, i03) using fastdiv (defined
    // outside this file) with the precomputed divisor parameters in p.
    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
    const uint i03_offset = i03 * p.ne01 * p.ne02;
    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
    const uint i01 = row - i03_offset - i02 * p.ne01;

    // Base element index of this row in the source and destination buffers.
    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

    const uint subgroup_id = tid / SUBGROUP_SIZE;

    // Carry-in for the first chunk. It is first read after the first
    // barrier() inside the loop, so that barrier makes this write visible.
    if (tid == 0) {
        last_sum = 0;
    }

    uint col = tid * ELEM_PER_THREAD;
    // The trip count depends only on p.n_cols and constants, so every
    // invocation runs the same number of iterations and reaches each barrier().
    const uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
    for (uint i = 0; i < num_iter; ++i) {
        // Level 1: v[j] = running sum of this invocation's elements 0..j.
        // Out-of-range columns contribute nothing but still receive a value,
        // so v[ELEM_PER_THREAD - 1] is always this invocation's chunk total.
        FLOAT_TYPE v[ELEM_PER_THREAD];
        FLOAT_TYPE thread_sum = 0;
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            if (col + j < p.n_cols) {
                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
            }
            v[j] = thread_sum;
        }

        // Level 2: add the totals of all lower lanes in this subgroup.
        thread_sum = subgroupExclusiveAdd(thread_sum);
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            v[j] += thread_sum;
        }

        // Level 3: the last lane of each subgroup now holds that subgroup's
        // chunk total; publish it, then add the totals of all lower subgroups
        // and the carry from the previous chunk.
        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
        }
        barrier(); // partial[] (and last_sum) writes are visible past this point
        for (uint s = 0; s < subgroup_id; ++s) {
            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
                v[j] += partial[s];
            }
        }
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            v[j] += last_sum;
        }

        // Every invocation has finished reading partial[] and last_sum for
        // this chunk before either is overwritten below or in the next chunk.
        barrier();
        // The last invocation holds the row total through the end of this
        // chunk; it becomes the carry for the next chunk.
        if (tid == BLOCK_SIZE - 1) {
            last_sum = v[ELEM_PER_THREAD - 1];
        }
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            if (col + j < p.n_cols) {
                data_d[dst_idx + col + j] = D_TYPE(v[j]);
            }
        }
        col += BLOCK_SIZE * ELEM_PER_THREAD;
    }
}