1#extension GL_EXT_control_flow_attributes : enable
  2#extension GL_EXT_shader_16bit_storage : require
  3#extension GL_EXT_shader_8bit_storage : require
  4
  5#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
  6#extension GL_KHR_shader_subgroup_basic : require
  7#extension GL_KHR_shader_subgroup_arithmetic : require
  8#endif
  9
 10#ifdef MUL_MAT_ID
 11#define EXPERT_COUNT 8
 12#endif
 13
 14#include "mul_mat_vec_iface.glsl"
 15
// Push-constant parameters shared by the mul_mat_vec shader variants.
// Strides/offsets are in elements (A offsets are later divided by QUANT_K,
// i.e. expressed in quant blocks).
layout (push_constant) uniform parameter
{
    uint ncols;           // number of columns of A (the reduction/K dimension)
    uint stride_a;        // row stride of matrix A
    uint stride_b;        // row stride of matrix B
    uint stride_d;        // row stride of the destination D

    uint batch_stride_a;  // elements between consecutive batches of A
    uint batch_stride_b;  // elements between consecutive batches of B
    uint batch_stride_d;  // elements between consecutive batches of D

    uint fusion_flags;    // MAT_VEC_FUSION_FLAGS_* bitmask enabling fused bias/scale epilogues

#ifdef MUL_MAT_ID
    // Indirect (expert-selection) path.
    uint nei0;        // NOTE(review): not referenced in this chunk — presumably the
                      // number of expert indices along i0; confirm against the dispatch code
    uint ne11;        // used to wrap expert_i0 when computing the B row (expert_i0 % ne11)
    uint expert_i1;   // second coordinate into data_ids and the B/D batch dimension
    uint nbi1;        // stride of data_ids along the i1 dimension
#else
    // Direct path: batch decomposition / broadcast factors.
    uint ne02;        // extent of A's dim 2 (used to linearize (i03, i02))
    uint ne12;        // extent of B's dim 2 (used to split batch_idx into (i13, i12))
    uint broadcast2;  // broadcast ratio along dim 2 (B batches per A batch)
    uint broadcast3;  // broadcast ratio along dim 3
#endif
} p;

#ifdef MUL_MAT_ID
// Expert selected for this workgroup; written in get_offsets() from data_ids,
// read later by reduce_result() for the fused-bias lookup.
uint expert_id;
#endif
 45
// Computes the element offsets into A, B and D for this invocation.
//
// Direct path (no MUL_MAT_ID): gl_GlobalInvocationID.y is a flat batch index;
// it is split into (i13, i12) via p.ne12, and the broadcast ratios map those
// onto A's (i03, i02) batch coordinates.
//
// MUL_MAT_ID path: gl_GlobalInvocationID.y selects an expert row; the expert
// id is fetched from data_ids and used to offset A, while B and D are indexed
// by (expert_i0, expert_i1).
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
    const uint expert_i0 = gl_GlobalInvocationID.y;
#else
    const uint batch_idx = gl_GlobalInvocationID.y;
#endif

#ifndef MUL_MAT_ID
    uint batch_idx_a = 0;
    if (batch_idx != 0) {
        // Split the flat batch index into dim-3 / dim-2 coordinates of B.
        const uint i13 = batch_idx / p.ne12;
        const uint i12 = batch_idx % p.ne12;

        // Map B's batch coordinates onto A's, collapsing broadcast dimensions.
        const uint i03 = i13 / p.broadcast3;
        const uint i02 = i12 / p.broadcast2;

        batch_idx_a = i03 * p.ne02 + i02;
    }
#else
    // Look up which expert this (expert_i0, expert_i1) pair routes to.
    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
#endif

    // A is quantized in blocks of QUANT_K elements, so its batch stride is
    // converted from elements to quant blocks here.
    a_offset =
#ifdef MUL_MAT_ID
            expert_id * (p.batch_stride_a / QUANT_K);
#else
            batch_idx_a * (p.batch_stride_a / QUANT_K);
#endif
    b_offset =
#ifdef MUL_MAT_ID
            (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
#else
            batch_idx * p.batch_stride_b;
#endif
    d_offset =
#ifdef MUL_MAT_ID
            expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
#else
            batch_idx * p.batch_stride_d;
#endif
}
 87
// Specialization constants chosen by the host at pipeline-creation time:
// workgroup size, and how many output rows / B columns each workgroup handles.
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
 91
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
// Reduces each thread's partial sums temp[col][row] to the final dot-product
// results and writes them to data_d, applying any fused bias/scale epilogues
// selected by p.fusion_flags. This variant uses subgroup adds only — no
// shared memory and no barrier — so it assumes the whole workgroup is a
// single subgroup (presumably guaranteed by the host when this path is
// compiled; confirm at dispatch).
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    // Sum the partial results across the subgroup for every (col, row) pair.
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    // subgroupAdd broadcasts the sum to all invocations; only thread 0
    // applies the fused epilogue and stores.
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
#ifdef MUL_MAT_ID
                // Fused bias: indexed by the selected expert's output row.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                // Fused per-expert-row scales: indexed by this workgroup's row
                // of the id matrix (gl_GlobalInvocationID.y), same coordinate
                // get_offsets() uses as expert_i0.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                // Fused biases: same indexing as the destination write below.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
}
#else
// Scratch space for cross-thread reduction: one slot per thread (tree-reduce
// path) or one slot per subgroup (USE_SUBGROUP_ADD path).
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];

// Same contract as the variant above, but supports workgroups larger than one
// subgroup by going through shared memory (either a per-subgroup partial-sum
// pass or a full binary tree reduction).
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    // subgroupAdd is probably faster on devices that support it,
    // particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
    // sum up partial sums within a subgroup
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    // Go through shared memory to sum partials across subgroups
    if (gl_SubgroupInvocationID == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                tmpsh[j][n][gl_SubgroupID] = temp[j][n];
            }
        }
    }
    barrier();
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                // Accumulate the per-subgroup partials serially on thread 0.
                temp[j][n] = FLOAT_TYPE(0);
                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
                    temp[j][n] += tmpsh[j][n][s];
                }
#ifdef MUL_MAT_ID
                // Fused epilogues — identical to the NO_SHMEM variant above.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
#else
    // sum up partial sums and write back result
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            tmpsh[j][n][tid] = temp[j][n];
        }
    }
    barrier();
    // Classic binary tree reduction in shared memory; halve the active
    // threads each step until tmpsh[..][..][0] holds the full sum.
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
                [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                    tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
                }
            }
        }
        barrier();
    }
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
#ifdef MUL_MAT_ID
                // Fused epilogues — identical to the variants above, applied
                // to the reduced value in slot 0.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
            }
        }
    }
#endif
}
#endif