author    Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit    b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree      211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp')
-rw-r--r--  llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp  620
1 file changed, 620 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
new file mode 100644
index 0000000..b6614d2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -0,0 +1,620 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_buffer_reference : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_vote : enable
+#ifdef DATA_A_BF16
+#extension GL_EXT_bfloat16 : enable
+#endif
+
+#include "types.glsl"
+#include "utils.glsl"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#define IS_MUL_MM2 1
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 256;
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
+
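+// When enable_smaller_matrices is set, output tiles whose valid columns fit within
+// BN/4 or BN/2 are computed with correspondingly narrower B operands and accumulators.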
+layout (constant_id = 4) const bool enable_smaller_matrices = false;
+const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
+const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+ // N dimension for the B matrix can be >= p.N
+ uint padded_N;
+} p;
+
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#if QUANT_K > 1
+#define DECODEFUNCA , dequantFuncA
+
+#include "dequant_funcs_cm2.glsl"
+
+#else
+#define DECODEFUNCA
+#endif
+
+#if !defined(fetch_scales)
+#define fetch_scales(a, b, c, d, e, f)
+#endif
+#if !defined(store_scales)
+#define store_scales(a)
+#endif
+
+#if defined(DATA_A_BF16)
+#define MAT_TYPE bfloat16_t
+#else
+#define MAT_TYPE FLOAT_TYPE
+#endif
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
+
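+// Maps each row of the current BN-wide output tile back to its source in B:
+// .x = row within the B batch (ii0 mod ne11), .y = batch index (ii1),
+// .z = the original index ii0, used when addressing D.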
+shared u16vec4 row_ids[BN];
+
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
+ B_TYPE b[];
+};
+
+uint _ne1;
+layout (constant_id = 5) const uint subgroup_size = 32;
+shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
+
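+// Decode callback for B: remaps the tile row through row_ids so only rows that were
+// routed to this expert feed the cooperative matrix multiply.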
+B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const uint row_i = blockCoords[0];
+
+ const u16vec4 row_idx = row_ids[row_i];
+ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
+
+ return ret;
+}
+
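+// Per-element store callback for D: scatters each result element to the row recorded
+// in row_ids, bounds-checked against p.M and the number of matched rows (_ne1).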
+D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
+{
+ uint dr = ir * BM + r;
+ uint dc = ic * BN + c;
+
+ if (dr < p.M && dc < _ne1) {
+ uint row_i = c;
+ const u16vec4 row_idx = row_ids[row_i];
+ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
+ }
+ return elem;
+}
+
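+// Scans the expert-id table and collects, in order, the rows whose id matches
+// expert_idx. _ne1 counts all matches seen so far; only matches that land in this
+// workgroup's slice of columns [ic*BN, (ic+1)*BN) are written into row_ids.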
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+ uint nei0shift = findLSB(p.nei0);
+
+ uint ids[16];
+ uint iter = 0;
+
+ uint expert_count = data_expert_count[expert_idx];
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
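+ // Workgroup-wide compaction: ballot the matches in each subgroup, publish the
+ // ballots through shared memory, then derive this invocation's output slot as
+ // (matches in earlier subgroups) + (exclusive bit count within this ballot).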
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+ ballots_sh[gl_SubgroupID] = ballot;
+ barrier();
+
+ uint subgroup_base = 0;
+ uint total = 0;
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier();
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
+ }
+ _ne1 += total;
+ iter &= 15;
+ if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
+ break;
+ }
+ }
+ barrier();
+}
+#endif
+
+void main() {
+ const uint tid = gl_LocalInvocationIndex;
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+ if (ic * BN >= data_expert_count[expert_idx]) {
+ return;
+ }
+ // initialize to row 0 so we don't need to bounds check
+ if (tid < BN) {
+ row_ids[tid] = u16vec4(0);
+ }
+#if !defined(NEEDS_INIT_IQ_SHMEM)
+ barrier();
+#endif
+#endif
+
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+#ifndef MUL_MAT_ID
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
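+ // gl_WorkGroupID.x encodes both the output row tile (ir) and, when split_k is used,
+ // the k-slice index (ik); each slice accumulates a disjoint range of K.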
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+
+#ifdef MUL_MAT_ID
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true, ic);
+ } else {
+ load_row_ids(expert_idx, false, ic);
+ }
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
+#ifdef MUL_MAT_ID
+ uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+#ifdef MUL_MAT_ID
+ uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
+ uint pos_b = 0;
+#else
+ uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
+ uint pos_b = batch_idx * p.batch_stride_b;
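+ // Each split-k slice writes its partial result to its own region of D
+ // (offset by ik * p.batch_stride_d * gl_NumWorkGroups.z), to be reduced afterwards.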
+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+ uint stride_a = p.stride_a / QUANT_K;
+ uint stride_b = p.stride_b;
+
+ // Hint to the compiler that values are aligned (want 16B alignment).
+ // Quants are always block-aligned, no alignment needed.
+#if ALIGNED
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+ stride_b &= ~7;
+#endif
+
+ // Create layouts for both clamped and unclamped accesses
+ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+
+#if QUANT_K > 1
+ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
+ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
+#endif
+
+ // Use end_k rather than p.K as the dimension because that's what
+ // we need to bound check against when using split_k.
+ // Bounds check B against padded_N, but bounds check D against N.
+ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
+ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
+ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
+ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
+
+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
+
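+ // Transposed view (swaps the two tensor dimensions): lets N x K slices of B load as
+ // K x N operands, and the BM x BN accumulator store into the N x M layout of D.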
+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
+
+#if !defined(MUL_MAT_ID)
+
+ const uint START_ALIGN_K = 256;
+ // For Qi_K (block size 256), unroll whole 256 element tiles.
+ // For legacy quants (block size 32), unroll 8x.
+ const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
+ const uint unroll_count = UNROLL_K / BK;
+
+ // Detect a fast path where all loads are entirely in bounds and no clamping is required
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
+#if QUANT_K == 1
+ (stride_a % 8) == 0 &&
+#endif
+ (stride_b % 8) == 0) {
+ // Hint to the compiler that values are aligned (want 16B alignment)
+ start_k &= ~(START_ALIGN_K-1);
+ stride_b &= ~7;
+#if QUANT_K == 1
+ stride_a &= ~7;
+#endif
+
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ uint k_iters = (end_k - start_k) / UNROLL_K;
+ uint block_k = start_k;
+
+ // fetch scale values for a tile of quants. These will be copied into shared memory.
+ // The fetches and stores are pipelined to hide the latency.
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
+
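+ // If the valid output columns of this tile fit within BN/4 (or BN/2), run the whole
+ // K loop with a proportionally narrower B operand and accumulator.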
+ if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
+ return;
+ } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
+ return;
+ } else {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+ return;
+ }
+ } else
+#endif // !defined(MUL_MAT_ID)
+ {
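+ // General path (always taken for MUL_MAT_ID): strides and extents are not known to
+ // be aligned or in-bounds, so each BK step checks its bounds and falls back to the
+ // clamped layouts at the matrix edges.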
+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
+
+ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
+
+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
+
+ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
+
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+ store_scales(tid);
+
+#ifdef MUL_MAT_ID
+ if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+ if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+#endif
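+ // Full BN-wide accumulator: the K loop below bounds-checks each step and, for
+ // MUL_MAT_ID, remaps B rows through decodeFuncB (otherwise it uses the clamped
+ // B layout).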
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+#ifdef MUL_MAT_ID
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+#else
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
+#endif
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+
+#ifdef MUL_MAT_ID
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+#else
+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+#endif
+ }
+}