#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#ifdef DATA_A_BF16
#extension GL_EXT_bfloat16 : enable
#endif

#include "types.glsl"
#include "utils.glsl"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#define IS_MUL_MM2 1

layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant

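// When enabled, narrower (BN/2 or BN/4) B and accumulator matrices may be
// substituted for output tiles near the end of the N dimension.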
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;

layout (push_constant) uniform parameter
{
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
    // N dimension for the B matrix can be >= p.N
    uint padded_N;
} p;


layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA

#include "dequant_funcs_cm2.glsl"

#else
#define DECODEFUNCA
#endif

#if !defined(fetch_scales)
#define fetch_scales(a, b, c, d, e, f)
#endif
#if !defined(store_scales)
#define store_scales(a)
#endif

#if defined(DATA_A_BF16)
#define MAT_TYPE bfloat16_t
#else
#define MAT_TYPE FLOAT_TYPE
#endif

#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};

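// For each column of this workgroup's output tile, the remapped source and
// destination row indices for the current expert, filled by load_row_ids().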
shared u16vec4 row_ids[BN];

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
    B_TYPE b[];
};

uint _ne1;
layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];

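// Tensor-load callback for B: indirects the row coordinate through row_ids so
// the cooperative-matrix load gathers only the rows routed to this expert.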
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint row_i = blockCoords[0];

    const u16vec4 row_idx = row_ids[row_i];
    B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];

    return ret;
}

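// Per-element store callback for D: scatters each result element to its
// remapped output row, skipping elements that fall outside M x _ne1.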
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
{
    uint dr = ir * BM + r;
    uint dc = ic * BN + c;

    if (dr < p.M && dc < _ne1) {
        uint row_i = c;
        const u16vec4 row_idx = row_ids[row_i];
        data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
    }
    return elem;
}

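// Scan the expert-id matrix for entries matching expert_idx and compact the
// matching coordinates into row_ids for output tile column ic. Subgroup
// ballots plus a cross-subgroup scan over ballots_sh give each matching
// invocation its position in the compacted list.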
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    uint expert_count = data_expert_count[expert_idx];

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements per invocation
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
            break;
        }
    }
    barrier();
}
#endif

void main() {
    const uint tid = gl_LocalInvocationIndex;
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
    if (ic * BN >= data_expert_count[expert_idx]) {
        return;
    }
    // initialize to row 0 so we don't need to bounds check
    if (tid < BN) {
        row_ids[tid] = u16vec4(0);
    }
#if !defined(NEEDS_INIT_IQ_SHMEM)
    barrier();
#endif
#endif

#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifndef MUL_MAT_ID
    const uint batch_idx = gl_GlobalInvocationID.z;

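    // Decode the flat batch index into B/D batch coordinates (i12, i13), then
    // map to A's batch (i02, i03) by dividing out the broadcast factors.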
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

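    // Decompose the x workgroup id into a tile row along M (ir) and a
    // K-split index (ik).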
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;

#ifdef MUL_MAT_ID
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    uint start_k = 0;
    const uint end_k = p.K;
#else
    uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

#ifdef MUL_MAT_ID
    uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
    uint pos_b = 0;
#else
    uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
    uint pos_b = batch_idx * p.batch_stride_b;
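    // With split_k, each k-split writes its partial result to a separate slice
    // of D (offset by ik); the slices are expected to be summed by a separate
    // reduction pass.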
    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    uint stride_a = p.stride_a / QUANT_K;
    uint stride_b = p.stride_b;

    // Hint to the compiler that values are aligned (want 16B alignment).
    // Quants are always block-aligned, so A needs no extra alignment hint when QUANT_K > 1.
#if ALIGNED
#if QUANT_K == 1
    stride_a &= ~7;
#endif
    stride_b &= ~7;
#endif

    // Create layouts for both clamped and unclamped accesses
    tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);

#if QUANT_K > 1
    tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
    tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
#endif

    // Use end_k rather than p.K as the dimension because that's what
    // we need to bounds-check against when using split_k.
    // Bounds-check B against padded_N, but bounds-check D against N.
    tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
    tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);

    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

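    // View that permutes the two dimensions (1, 0): B is stored as N x K, and
    // this view lets the coopmat loads/stores treat it as transposed.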
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if !defined(MUL_MAT_ID)

    const uint START_ALIGN_K = 256;
    // For Qi_K (block size 256), unroll whole 256-element tiles.
    // For legacy quants (block size 32), unroll 8x.
    const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
    const uint unroll_count = UNROLL_K / BK;

    // Detect a fast path where all loads are entirely in bounds and no clamping is required
    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
        (stride_a % 8) == 0 &&
#endif
        (stride_b % 8) == 0) {
        // Hint to the compiler that values are aligned (want 16B alignment)
        start_k &= ~(START_ALIGN_K-1);
        stride_b &= ~7;
#if QUANT_K == 1
        stride_a &= ~7;
#endif

        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

        uint k_iters = (end_k - start_k) / UNROLL_K;
        uint block_k = start_k;

        // Fetch scale values for a tile of quants. These will be copied into shared memory.
        // The fetches and stores are pipelined to hide the latency.
        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);

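        // If the remaining output columns fit in a quarter-width (BN/4) tile,
        // use narrower B and accumulator matrices; the BN/2 case follows.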
        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
            for (uint i = 0; i < k_iters; ++i) {

                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manual partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }
            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }
#if defined(ACC_TYPE_MAX)
            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
            return;
        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
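            // Same as the BN/4 case above, but with half-width (BN/2) tiles.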
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
            for (uint i = 0; i < k_iters; ++i) {

                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manual partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }
            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }
#if defined(ACC_TYPE_MAX)
            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
            return;
        } else {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

            for (uint i = 0; i < k_iters; ++i) {

                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manual partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }
            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }
#if defined(ACC_TYPE_MAX)
            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
            return;
        }
    } else
#endif // !defined(MUL_MAT_ID)
    {
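        // General path with bounds checking (clamp layouts for tiles that may
        // run out of range); also the only path when MUL_MAT_ID is defined.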
        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);

        tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);

        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

        tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

        uint k_iters = (end_k - start_k + BK - 1) / BK;

        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
        store_scales(tid);

#ifdef MUL_MAT_ID
        if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
            sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);

            [[dont_unroll]]
            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

                if ((block_k % QUANT_K) == 0) {
                    store_scales(tid);
                }
                if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
                }

                if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                } else {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                }
            }
#if defined(ACC_TYPE_MAX)
            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

            // Convert from ACC_TYPE to D_TYPE
            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
            mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

            // Call callback to store each element, remapping row through shared memory
            coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
            return;
        }
        if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
            sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);

            [[dont_unroll]]
            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

                if ((block_k % QUANT_K) == 0) {
                    store_scales(tid);
                }
                if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
                }

                if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                } else {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                }
            }
#if defined(ACC_TYPE_MAX)
            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

            // Convert from ACC_TYPE to D_TYPE
            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
            mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

            // Call callback to store each element, remapping row through shared memory
            coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
            return;
        }
#endif
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

        [[dont_unroll]]
        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

            if ((block_k % QUANT_K) == 0) {
                store_scales(tid);
            }
            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
            }

            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            } else {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
        }
#if defined(ACC_TYPE_MAX)
        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

        // Convert from ACC_TYPE to D_TYPE
        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

#ifdef MUL_MAT_ID
        // Call callback to store each element, remapping row through shared memory
        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
    }
}