#version 450

// Compute shader for a tiled matrix multiplication.
//   A is read from binding 0 as A_TYPE blocks (type supplied by the build),
//   B is read from binding 1 as block_q8_1_x4_packed128,
//   D (the result) is written to binding 2 as D_TYPE.
// When MUL_MAT_ID is defined, two further buffers (ids and per-expert counts)
// select which rows each workgroup processes.

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

#extension GL_EXT_integer_dot_product : require

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif

#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

#ifdef MUL_MAT_ID
// main() stores row indices as u16vec2, which needs explicit 16-bit integer types.
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#include "types.glsl"

// Workgroup width comes from specialization constant 0 (the same id as BLOCK_SIZE below).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Matrix A. The optional PACKED16/PACKED32 blocks alias the same binding 0 with
// different element types.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
// Matrix B.
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
// Output matrix D.
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// data_ids is scanned by main() for entries equal to the current expert index;
// data_expert_count[expert] is used by main() as an early-out bound.
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif

// Push constants, as used by main():
//   M, N            - bounds applied to the output row / column index before writing data_d
//   K               - upper bound of the reduction range
//   stride_*        - strides used when indexing A, B and D
//   batch_stride_*  - per-batch strides for A, B and D
// MUL_MAT_ID variant:
//   nei0, nei1      - inner / outer extents of the data_ids scan
//   nbi1            - stride between outer rows of data_ids
//   ne11            - modulus applied to the ids x index when addressing B
// Non-MUL_MAT_ID variant:
//   k_split         - length of the K range handled by one workgroup
//   ne02, ne12, broadcast2, broadcast3
//                   - used to map the flat batch index onto A's batch index
layout (push_constant) uniform parameter
{
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
} p;

// Specialization constants controlling the tiling. The values given here are
// only defaults; the pipeline may override them.
layout (constant_id = 0) const uint BLOCK_SIZE = 64;  // invocations per workgroup (see local_size_x_id)
layout (constant_id = 1) const uint BM = 64;          // output tile height per workgroup
layout (constant_id = 2) const uint BN = 64;          // output tile width per workgroup
// constant_id 3 is not declared in this shader; BK is the fixed #define below.
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

// Width of one K block; A/B block positions in main() are expressed in units of BK.
#define BK 32

#include "mul_mmq_shmem_types.glsl"

// Number of BK-wide blocks staged per iteration of the main K loop.
// Forced to 1 for MUL_MAT_ID; otherwise 4 unless the build predefines it.
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif

// Shared memory cache: BK_STEP slices of BM (A) and BN (B) cached blocks.
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache (per-invocation copies filled from the shared cache).
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;

// Load widths used by main() to split a BK-wide block across invocations.
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16

#define NUM_WARPS (BLOCK_SIZE / WARP)

#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"

// Shader entry point.
//
// Each workgroup accumulates one BM x BN output tile (row tile `ir`, column
// tile `ic`) over the K range [start_k, end_k) and writes it to data_d.
// Per K iteration, BK_STEP blocks of A and B are staged into buf_a / buf_b,
// copied into the per-invocation caches and combined with mmq_dot_product().
//
// The helpers block_a_to_shmem, block_b_to_shmem, block_a_to_registers,
// block_b_to_registers, mmq_dot_product, load_row_ids and init_iq_shmem, as
// well as `_ne1` and `row_ids`, are not defined in this file; they are
// presumably provided by the included *.glsl files.
void main() {
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
    // Nothing to do if this column tile starts past the expert's count.
    if (ic * BN >= data_expert_count[expert_idx]) {
        return;
    }
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifndef MUL_MAT_ID
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Split the flat batch index into (i13, i12), then divide by the broadcast
    // factors to obtain the batch index used for A.
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    // gl_WorkGroupID.x encodes both the row tile (ir) and the K-split slice (ik).
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;

    // Tiling derived from the specialization constants.
    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    const uint tiw = gl_LocalInvocationID.x % WARP;

    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);

    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // Which part of a BK-wide block (loadr_*) and which tile row/column
    // (loadc_*) this invocation stages into shared memory.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);

    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    // The boolean argument tells load_row_ids whether nei0 is a power of two.
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Scan data_ids for entries equal to expert_idx. Matches whose ordinal
    // falls inside this column tile, [ic * BN, (ic + 1) * BN), are recorded in
    // row_ids as (ii0, ii1). The loop bounds do not depend on the invocation
    // id, so every invocation runs the same scan and writes the same values.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

    // Reduction range handled by this workgroup.
#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Starting positions in A and B, in units of BK-element blocks.
    uint pos_a_ib =
#ifdef MUL_MAT_ID
        expert_idx * (p.batch_stride_a / BK) +
#else
        batch_idx_a * (p.batch_stride_a / BK) +
#endif
        (ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
    uint pos_b_ib = 0;
#else
    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

    // Per-invocation accumulators, one per owned output element.
    ACC_TYPE sums[WMITER * TM * WNITER * TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }

    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
        // Stage A: only k-steps that lie below end_k are written to buf_a.
        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
            const uint buf_ib = loadc_a + l;
            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
            const uint iqs = loadr_a;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                if (block + k_step * BK < end_k) {
                    block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
                }
            }
        }
        // Stage B: the helper is always called and receives the in-range test
        // as its last argument.
        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
            const uint buf_ib = loadc_b + l;

#ifdef MUL_MAT_ID
            const u16vec2 row_idx = row_ids[buf_ib];
            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
            const uint iqs = loadr_b;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs, block + k_step * BK < end_k);
            }
        }

        // Make the staged tiles visible to the whole workgroup.
        barrier();

        pos_a_ib += BK_STEP;
        pos_b_ib += BK_STEP;

        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
            // The A loader above skips k-steps at or beyond end_k, so their
            // buf_a slots hold stale data (or uninitialized shared memory on
            // the first iteration). Those k-steps lie outside the reduction
            // range and must not contribute, so stop here instead of relying
            // on the B side evaluating to zero. The condition depends only on
            // workgroup-uniform values, and it is monotonic in k_step, so a
            // break is sufficient.
            if (block + k_step * BK >= end_k) {
                break;
            }

            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint reg_ib = wsir * TM + cr;
                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;

                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                }
            }

            // For each owned B column: load it once, then accumulate against
            // every cached A row.
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
                    block_b_to_registers(ib);

                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint cache_a_idx = wsir * TM + cr;
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;

                            sums[sums_idx] += mmq_dot_product(cache_a_idx);
                        }
                    }
                }
            }
        }

        // All reads of buf_a / buf_b must finish before the next iteration
        // overwrites them.
        barrier();
    }

    // Top-left corner of this warp's part of the output tile.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Each K-split slice (ik) writes to its own region of data_d, offset by
    // batch_stride_d * gl_NumWorkGroups.z per slice.
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                // Columns at or beyond _ne1 have no row_ids entry; stop.
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
                    if (dr_warp + cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#else
                    // Tiles may overhang the matrix; only write in-bounds elements.
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
}