1#extension GL_EXT_control_flow_attributes : enable
2#extension GL_EXT_shader_16bit_storage : require
3#extension GL_EXT_shader_8bit_storage : require
4
5#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
6#extension GL_KHR_shader_subgroup_basic : require
7#extension GL_KHR_shader_subgroup_arithmetic : require
8#endif
9
10#ifdef MUL_MAT_ID
11#define EXPERT_COUNT 8
12#endif
13
14#include "mul_mat_vec_iface.glsl"
15
// Push-constant parameters for the mat-vec multiply shaders.
// NOTE(review): member order/types must match the host-side push-constant
// struct exactly — do not reorder or resize fields.
layout (push_constant) uniform parameter
{
    uint ncols;           // length of each dot product (columns of A) — per host setup
    uint stride_a;        // row stride of A (units per host setup; see QUANT_K division below)
    uint stride_b;        // row stride of B
    uint stride_d;        // row stride of D

    uint batch_stride_a;  // per-batch (or per-expert, in MUL_MAT_ID mode) stride of A
    uint batch_stride_b;  // per-batch stride of B
    uint batch_stride_d;  // per-batch stride of D

    uint fusion_flags;    // bitmask of MAT_VEC_FUSION_FLAGS_* selecting the fused bias/scale epilogue

#ifdef MUL_MAT_ID
    // Expert-routing (mul_mat_id) parameters.
    uint nei0;            // presumably number of expert indices along dim 0 — confirm against host
    uint ne11;            // B rows wrapped by `expert_i0 % ne11` in get_offsets()
    uint expert_i1;       // second index into data_ids and batch offset for B/D
    uint nbi1;            // stride of data_ids along its second dimension
#else
    // Batched/broadcast matmul parameters.
    uint ne02;            // A batch extent along dim 2
    uint ne12;            // B batch extent along dim 2 (used to split the flat batch index)
    uint broadcast2;      // broadcast factor mapping B's dim-2 batch index onto A's
    uint broadcast3;      // broadcast factor mapping B's dim-3 batch index onto A's
#endif
} p;

#ifdef MUL_MAT_ID
// Expert selected for this workgroup; written by get_offsets(), read by the
// fused-bias path in reduce_result().
uint expert_id;
#endif
45
// Compute base offsets into the A, B and D buffers for this invocation.
// Without MUL_MAT_ID: gl_GlobalInvocationID.y is a flat batch index over B/D;
// it is split into (i13, i12) and divided by the broadcast factors to find the
// (possibly smaller) matching A batch.
// With MUL_MAT_ID: gl_GlobalInvocationID.y selects an expert slot; the routed
// expert is loaded from data_ids into the global `expert_id` as a side effect.
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
    const uint expert_i0 = gl_GlobalInvocationID.y;
#else
    const uint batch_idx = gl_GlobalInvocationID.y;
#endif

#ifndef MUL_MAT_ID
    uint batch_idx_a = 0;
    if (batch_idx != 0) {
        // Split the flat batch index into B's (i13, i12) coordinates ...
        const uint i13 = batch_idx / p.ne12;
        const uint i12 = batch_idx % p.ne12;

        // ... then map onto A's batch coordinates via the broadcast factors
        // (A may have fewer batches than B when broadcasting).
        const uint i03 = i13 / p.broadcast3;
        const uint i02 = i12 / p.broadcast2;

        batch_idx_a = i03 * p.ne02 + i02;
    }
#else
    // Look up which expert this slot is routed to.
    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
#endif

    a_offset =
#ifdef MUL_MAT_ID
        // batch_stride_a is divided by QUANT_K — A is presumably addressed in
        // QUANT_K-element quant blocks while the stride is in elements; confirm
        // against the host-side setup.
        expert_id * (p.batch_stride_a / QUANT_K);
#else
        batch_idx_a * (p.batch_stride_a / QUANT_K);
#endif
    b_offset =
#ifdef MUL_MAT_ID
        // Expert slots wrap around B's ne11 rows.
        (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
#else
        batch_idx * p.batch_stride_b;
#endif
    d_offset =
#ifdef MUL_MAT_ID
        expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
#else
        batch_idx * p.batch_stride_d;
#endif
}
87
// Specialization constants, set at pipeline-creation time.
layout (constant_id = 0) const uint BLOCK_SIZE = 32;  // invocations taking part in the reduction (sizes tmpsh)
layout (constant_id = 1) const uint NUM_ROWS = 1;     // output rows accumulated per workgroup (temp extent)
layout (constant_id = 2) const uint NUM_COLS = 1;     // B/D columns processed per invocation (temp extent)
91
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
// Reduce the per-invocation partial sums in `temp` using subgroup arithmetic
// only (no shared memory), apply the fused bias/scale epilogue selected by
// p.fusion_flags, and store the results to data_d.
// NOTE(review): the result is written from tid == 0 after a single
// subgroupAdd, which is only the full workgroup sum if the workgroup consists
// of exactly one subgroup — confirm the dispatch side guarantees this.
//
// temp      - partial dot products, indexed [column][row]; reduced in place
// d_offset  - base offset into data_d for this batch/expert slot
// first_row - first output row handled by this workgroup
// num_rows  - number of valid rows in temp (<= NUM_ROWS)
// tid       - invocation index within the workgroup
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    // Sum partial results across the subgroup.
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
#ifdef MUL_MAT_ID
                // Fused epilogue (expert-routed): per-expert bias row, then up
                // to two per-slot scalar multipliers indexed by the slot id.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                // Fused epilogue (batched): up to two bias tensors addressed
                // with the same layout as the destination.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
}
128#else
// Shared-memory scratch for cross-subgroup / cross-invocation reduction.
// Last dimension is indexed by gl_SubgroupID (USE_SUBGROUP_ADD path) or by
// tid (tree-reduction path), both of which are < BLOCK_SIZE.
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];

// Reduce the per-invocation partial sums in `temp` across the workgroup via
// shared memory, apply the fused bias/scale epilogue selected by
// p.fusion_flags, and store the results to data_d.
//
// temp      - partial dot products, indexed [column][row]
// d_offset  - base offset into data_d for this batch/expert slot
// first_row - first output row handled by this workgroup
// num_rows  - number of valid rows in temp (<= NUM_ROWS)
// tid       - invocation index within the workgroup
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    // subgroupAdd is probably faster on devices that support it,
    // particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
    // sum up partial sums within a subgroup
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    // Go through shared memory to sum partials across subgroups
    if (gl_SubgroupInvocationID == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                tmpsh[j][n][gl_SubgroupID] = temp[j][n];
            }
        }
    }
    barrier();
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                // Accumulate one partial per subgroup, then run the fused
                // epilogue and write out.
                temp[j][n] = FLOAT_TYPE(0);
                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
                    temp[j][n] += tmpsh[j][n][s];
                }
#ifdef MUL_MAT_ID
                // Fused epilogue (expert-routed): per-expert bias row, then up
                // to two per-slot scalar multipliers indexed by the slot id.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                // Fused epilogue (batched): up to two bias tensors addressed
                // with the same layout as the destination.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
#else
    // sum up partial sums and write back result
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            tmpsh[j][n][tid] = temp[j][n];
        }
    }
    barrier();
    // Classic shared-memory tree reduction: halve the active invocations each
    // step (assumes BLOCK_SIZE is a power of two, per the halving loop).
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
                [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                    tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
                }
            }
        }
        barrier();
    }
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
#ifdef MUL_MAT_ID
                // Fused epilogue (expert-routed) — mirrors the subgroup path above.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                    const uint expert_i0 = gl_GlobalInvocationID.y;
                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                }
#else
                // Fused epilogue (batched) — mirrors the subgroup path above.
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
                }
                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                }
#endif
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
            }
        }
    }
#endif
}
#endif
229#endif