1#version 450
2
3#extension GL_EXT_control_flow_attributes : enable
4#extension GL_EXT_shader_16bit_storage : require
5#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
6
7#extension GL_EXT_integer_dot_product : require
8
9#ifdef FLOAT16
10#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
11#endif
12
13#if defined(MUL_MAT_ID_USE_SUBGROUPS)
14#extension GL_KHR_shader_subgroup_basic : enable
15#extension GL_KHR_shader_subgroup_ballot : enable
16#endif
17
18#ifdef MUL_MAT_ID
19#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
20#endif
21
22#include "types.glsl"
23
// One workgroup has BLOCK_SIZE invocations along x (specialization constant
// id 0); y and z stay 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Binding 0: quantized matrix A. The same buffer is aliased under packed
// 16-bit / 32-bit views when the quant type defines them, allowing wider
// loads from the same data.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
// Binding 1: matrix B, quantized as q8_1 blocks in 128-bit-packed groups of four.
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
// Binding 2: output matrix D.
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// Mixture-of-experts indirection: an expert id per (ii0, ii1) entry of the id
// matrix, and the number of rows routed to each expert.
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
40
// Push constants: problem sizes, element strides, and either MoE indexing
// (MUL_MAT_ID) or batch-broadcast plus K-split parameters.
layout (push_constant) uniform parameter
{
    uint M; // rows of A and D
    uint N; // columns of B and D
    uint K; // shared reduction dimension
    uint stride_a;
    uint stride_b;
    uint stride_d;

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0; // expert-id matrix extent along dim 0
    uint nei1; // expert-id matrix extent along dim 1
    uint nbi1; // expert-id matrix row stride (elements)
    uint ne11; // row_idx.x is wrapped modulo this when indexing into B
#else
    uint k_split;    // length of the K range each k-slice workgroup reduces
    uint ne02;       // A's batch extent along dim 2
    uint ne12;       // output batch extent along dim 2 (batch_idx = i13*ne12 + i12)
    uint broadcast2; // batch broadcast factors along dims 2 and 3
    uint broadcast3;
#endif
} p;
67
// Tiling parameters, supplied as specialization constants:
//   BLOCK_SIZE  invocations per workgroup
//   BM x BN     output tile computed per workgroup
//   WM x WN     subtile computed per warp (WMITER steps; WNITER is derived in main)
//   TM x TN     register tile accumulated per thread
//   WARP        warp/subgroup size assumed by the tiling
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

// K-depth of one shared-memory tile; fixed at 32 here (constant_id 3 is
// reserved but commented out above).
#define BK 32

#include "mul_mmq_shmem_types.glsl"

// Number of BK-deep tiles staged into shared memory per outer-loop iteration;
// the MoE path uses a single step.
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif

// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;

// Elements moved per thread per cooperative shared-memory load, for A and B.
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16

#define NUM_WARPS (BLOCK_SIZE / WARP)

#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
106
void main() {
    // Index of this workgroup's BN-wide output column tile.
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    // Mixture-of-experts: one z-slice of the dispatch per expert. Exit early
    // if this column tile starts past the number of rows routed to the expert.
    const uint expert_idx = gl_GlobalInvocationID.z;
    if (ic * BN >= data_expert_count[expert_idx]) {
        return;
    }
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
    // Some quant formats stage lookup tables into shared memory first.
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifndef MUL_MAT_ID
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Decompose the flat batch index into (i13, i12), then divide by the
    // broadcast factors to locate the corresponding source batch of A.
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    // Workgroup x packs both the BM-wide row tile index (ir) and, in the
    // non-ID path, the K-split slice index (ik).
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;

    // Warp tiling: each warp owns a WM x WN subtile, walked in WMITER x
    // WNITER steps of WSUBM x WSUBN; each thread accumulates a TM x TN
    // register tile per step.
    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    // Thread index within the warp ...
    const uint tiw = gl_LocalInvocationID.x % WARP;

    // ... and its (row, column) position inside the WSUBM x WSUBN step.
    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);

    // Warp (row, column) position inside the BM x BN workgroup tile.
    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // Per-thread coordinates for the cooperative shared-memory loads:
    // loadr_* selects the position along K, loadc_* the row/column, and
    // loadstride_* is how many rows/columns the whole workgroup covers per
    // inner load-loop step.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);

    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    // Subgroup-accelerated gather of this tile's row ids; a power-of-two nei0
    // takes a specialized path.
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Scalar fallback: scan the id matrix and record, for rows routed to this
    // expert, the (ii0, ii1) coordinates that fall inside this BN-wide tile.
    // NOTE(review): every invocation runs the same scan redundantly; _ne1 and
    // row_ids are declared in mul_mm_id_funcs.glsl and appear to be shared —
    // confirm there.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    // Each k-split slice reduces only its own [start_k, end_k) range of K.
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Starting indices into A and B, in units of BK-element blocks.
    uint pos_a_ib =
#ifdef MUL_MAT_ID
        expert_idx * (p.batch_stride_a / BK) +
#else
        batch_idx_a * (p.batch_stride_a / BK) +
#endif
        (ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
    // B offsets are resolved per row via row_ids inside the main loop.
    uint pos_b_ib = 0;
#else
    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

    // Per-thread accumulators: one TM x TN tile per (wsir, wsic) step.
    ACC_TYPE sums[WMITER * TM * WNITER * TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }

    // Main loop over K: stage BK_STEP BK-deep tiles of A and B into shared
    // memory, then accumulate their dot products into the register tiles.
    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
        // Cooperative load of the BM x (BK * BK_STEP) tile of A.
        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
            const uint buf_ib = loadc_a + l;
            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
            const uint iqs = loadr_a;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                // Skip sub-tiles that start past the end of this K range.
                if (block + k_step * BK < end_k) {
                    block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
                }
            }
        }
        // Cooperative load of the BN x (BK * BK_STEP) tile of B.
        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
            const uint buf_ib = loadc_b + l;

#ifdef MUL_MAT_ID
            // Indirect row lookup; row_idx.x is wrapped into B's row range.
            const u16vec2 row_idx = row_ids[buf_ib];
            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
            const uint iqs = loadr_b;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                // The in-range flag is passed down instead of branching here.
                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs, block + k_step * BK < end_k);
            }
        }

        // Wait until the shared-memory tiles are fully populated.
        barrier();

        pos_a_ib += BK_STEP;
        pos_b_ib += BK_STEP;

        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint reg_ib = wsir * TM + cr;
                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;

                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                }
            }

            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
                    block_b_to_registers(ib);

                    // Accumulate one cached B column against all cached A rows.
                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint cache_a_idx = wsir * TM + cr;
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;

                            sums[sums_idx] += mmq_dot_product(cache_a_idx);
                        }
                    }
                }
            }
        }

        // Don't overwrite the shared tiles while any warp is still reading.
        barrier();
    }

    // Top-left coordinates of this warp's output subtile in D.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Each k-split slice writes its partial result to a distinct region of D
    // (offset by ik); presumably a later pass reduces the slices — confirm
    // against the dispatching code.
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    // Write the accumulated register tiles to D, clamped to the matrix bounds.
    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                // Scatter back through the gathered row ids.
                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
                    if (dr_warp + cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#else
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
}