#version 450

// Extensions. Control-flow attributes provide the [[unroll]] hints used in
// main() and 16-bit storage is always required; every other extension is
// requested only by the build variant that needs it.
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#if defined(DATA_A_IQ1_M)
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif

#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif

#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

// The MUL_MAT_ID variant stores row ids as u16vec2, which needs 16-bit integer types.
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#include "types.glsl"

// Number of elements covered by one load from A / B; defaults to 1 when the
// build does not provide a value.
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
#define LOAD_VEC_B 1
#endif

// Unaligned variants load 2 values at once. The factor is kept separate from
// LOAD_VEC_A / LOAD_VEC_B so that index calculations based on LOAD_VEC_* are
// not affected. For A this only applies to F32 / F16 / BF16 data.
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif

// TO_FLOAT_TYPE falls back to FLOAT_TYPE unless the build overrides it.
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif

// Workgroup size along x comes from specialization constant 0.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Matrix A. The packed declarations alias the same binding so the same buffer
// can be read through A_TYPE_PACKED16 / A_TYPE_PACKED32 when those are defined.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

// Matrix B (input) and matrix D (output, write-only).
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// Expert ids scanned by main() and the per-expert count used for its early exit.
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif

layout (push_constant) uniform parameter
{
    // M / N bound the output rows / columns written by main(); K is the upper
    // end of the reduction range.
    uint M;
    uint N;
    uint K;
    // Element strides between consecutive rows of A and columns of B / D, as
    // used by the index math in main().
    uint stride_a;
    uint stride_b;
    uint stride_d;

    // Element strides between consecutive batches (experts for A in MUL_MAT_ID).
    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    // nei0 / nei1 are the extents main() iterates over in data_ids and nbi1 is
    // the stride of its second index. ne11 is not referenced in this file
    // (presumably consumed by the included helpers - TODO confirm).
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    // k_split is the length of the K range handled by one split-k slice.
    // ne02 / ne12 / broadcast2 / broadcast3 map the batch index of B onto the
    // batch index of A in main().
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
} p;
100
// Specialization constants; the values given here are defaults. constant_id 0
// is also the id used for the workgroup size along x.
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
// BM x BN is the output tile computed by one workgroup.
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// WM x WN is the part of that tile handled by one warp; WMITER splits WM into
// WM / WMITER sub-tiles in the non-coopmat path.
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
// TM x TN is the per-invocation tile in the non-coopmat path and the
// cooperative-matrix accumulator shape in the COOPMAT path.
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

// BK is the number of K elements staged in shared memory per outer iteration;
// BK_STEP is the number of K elements consumed per inner iteration of the
// non-coopmat path (4 with vec4 caches, 2 with vec2 caches).
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#define BK 32
#define BK_STEP 4
#else
layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
#define BK_STEP 2
#endif

// Row stride of buf_a / buf_b in FLOAT_TYPE_VEC2 elements: BK / 2 elements of
// data plus padding (presumably to avoid shared-memory bank conflicts and, for
// coopmat, to satisfy load alignment - TODO confirm).
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 2 + 4)
#else
#define SHMEM_STRIDE (BK / 2 + 1)
#endif

// Shared staging buffers for the current BK-wide slice of the A and B tiles.
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef COOPMAT
// One TM x TN staging tile per warp; main() stores accumulators here when the
// result has to be written to data_d element by element.
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif

#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
137
// Entry point. Each workgroup accumulates one BM x BN tile of D over the K
// range [start_k, end_k): every outer iteration stages a BK-wide slice of the
// A and B tiles through load_a_to_shmem / load_b_to_shmem, multiplies it out of
// buf_a / buf_b into per-invocation accumulators (cooperative matrices when
// COOPMAT is defined), and the accumulators are written to data_d at the end.
// With MUL_MAT_ID the columns of the tile are the rows selected for one expert.
void main() {
    // Column-tile index of this workgroup.
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
    // Nothing to do when this tile starts beyond the expert's entry count.
    if (ic * BN >= data_expert_count[expert_idx]) {
        return;
    }
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifndef MUL_MAT_ID
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Split the flat batch index of B into (i13, i12) and derive the matching
    // batch of A through the broadcast factors.
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    // gl_WorkGroupID.x enumerates row tiles first (ir), then K slices (ik).
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;

    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;

#ifdef COOPMAT
    const uint warp_i = gl_SubgroupID;

    const uint tiw = gl_SubgroupInvocationID;

    // Number of TM x TN cooperative matrices along each side of the warp tile.
    const uint cms_per_row = WM / TM;
    const uint cms_per_col = WN / TN;

    // Element-wise stores from coopmat_stage: each invocation handles row
    // store_r and columns store_c, store_c + storestride, ...
    const uint storestride = WARP / TM;
    const uint store_r = tiw % TM;
    const uint store_c = tiw / TM;
#else
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    const uint tiw = gl_LocalInvocationID.x % WARP;

    // Position of this invocation inside a WSUBM x WSUBN sub-tile.
    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);
#endif

    // Position of this warp's WM x WN tile inside the workgroup tile.
    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // Load assignment: loadr_* is the position along K, loadc_* the row of A /
    // column of B; loadstride_* is how many rows / columns the whole workgroup
    // covers in one pass.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);

    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;

#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    // The second argument is passed as a literal in both branches, selected by
    // whether nei0 is a power of two.
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Scan data_ids for entries that select this expert. _ne1 counts the
    // matches found so far; the (ii0, ii1) pairs of the matches that fall into
    // this tile, [ic * BN, (ic + 1) * BN), are recorded in row_ids. The scan
    // stops as soon as the tile is full. _ne1 and row_ids are declared outside
    // this file.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

    // K range of this workgroup: everything for MUL_MAT_ID, otherwise the
    // ik-th slice of length k_split, clipped to K.
#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Start offsets into A and B, in units of LOAD_VEC_A / LOAD_VEC_B elements.
    uint pos_a =
#ifdef MUL_MAT_ID
        expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
#else
        batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
#endif
        (ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
    uint pos_b = 0;
#else
    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif

    // Accumulators, initialised to zero.
#ifdef COOPMAT
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
    }
#else
    // Each vec2 accumulator holds two consecutive rows of one output column.
    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
#if defined(DATA_A_F32) || defined(DATA_A_F16)
    FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
    FLOAT_TYPE_VEC4 cache_b;
#else
    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
    FLOAT_TYPE_VEC2 cache_b;
#endif

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
        sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
    }
#endif

    for (uint block = start_k; block < end_k; block += BK) {
        // Stage the next BK-wide slice of the A and B tiles.
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if !defined(MUL_MAT_ID)
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
#else
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
#endif
        }

        // The whole slice must be staged before any invocation reads it.
        barrier();

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

#ifdef COOPMAT
        // The shared buffers hold vec2 elements, hence the offset i / 2.
        [[unroll]] for (uint i = 0; i < BK; i += TK) {
            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
                // Load from shared into cache
                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
                }
            }
        }
#else
        // Each iteration consumes BK_STEP values along K: two vec2 elements per
        // row for F32 / F16 data, one vec2 element otherwise.
        [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {
            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {
                #if defined(DATA_A_F32) || defined(DATA_A_F16)
                    cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i    ];
                    cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
                #else
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                #endif
                }
            }

            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                #if defined(DATA_A_F32) || defined(DATA_A_F16)
                    cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i    ];
                    cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
                #else
                    cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
                #endif

                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                        [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
                            // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
                            const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
                        #if defined(DATA_A_F32) || defined(DATA_A_F16)
                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y),
                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
                            sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
                        #else
                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
                            sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
                        #endif
                        }
                    }
                }
            }

        }
#endif

        // All reads of this slice must finish before the next one overwrites it.
        barrier();
    }

    // Optionally clamp the accumulators to [-ACC_TYPE_MAX, ACC_TYPE_MAX].
#if defined(ACC_TYPE_MAX)
#ifdef COOPMAT
    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
        }
    }
#else
    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
        sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
        sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
#endif
#endif

    // First output row / column of this warp's tile.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Output offset of this batch. Results of K slice ik are placed after the
    // results of all batches of the preceding slices.
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#ifdef COOPMAT
#ifdef MUL_MAT_ID
    // Each accumulator is staged in coopmat_stage and then scattered element by
    // element to the destination selected by row_ids.
    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

            [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                const uint row_i = dc + cm_col * TN + col + store_c;
                // Columns at or beyond _ne1 have no row id recorded.
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];

                if (dr + cm_row * TM + store_r < p.M) {
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            }
        }
    }
#else
    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float

    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;

            if (is_aligned && is_in_bounds) {
                // Full coopMat is within bounds and stride_d is aligned with 16B
                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
            } else if (is_in_bounds) {
                // Full coopMat is within bounds, but stride_d is not aligned
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
                // Partial coopMat is within bounds
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                    }
                }
            }
        }
    }
#endif // MUL_MAT_ID
#else
    // Non-coopmat store: each invocation writes its own TM x TN values per
    // sub-tile, bounds-checked against M (and N, or _ne1 for MUL_MAT_ID).
    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                const uint row_i = dc_warp + cc;
                // Columns at or beyond _ne1 have no row id recorded.
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
                    // Same [wsic][cc][wsir][cr] layout as in the accumulation loop.
                    const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID
                    if (dr_warp + 2 * cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
                    }
                    if (dr_warp + 2 * cr + 1 < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                    }
#else
                    if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
                    }
                    if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
#endif // COOPMAT
}