#version 450
// Tiled matrix multiplication of a quantized A matrix with a B matrix bound as
// block_q8_1_x4_packed128 blocks, accumulated with integer dot products
// (GL_EXT_integer_dot_product).
//
// Data layout, as established by the index arithmetic in main():
//   - A and B are both addressed row by row with K running along each row, in
//     units of BK-element blocks (every A/B offset is divided by BK, and the
//     K position advances by one block index per BK elements).
//   - Output element (m, n), with m < M and n < N, is written to
//     data_d[offset + n * stride_d + m], i.e. m is the contiguous dimension of D.
//     It accumulates mmq_dot_product() results over K for row m of A and row n
//     of B.
//   - Without MUL_MAT_ID, gl_GlobalInvocationID.z selects a batch and K may be
//     split across workgroups (p.k_split); each K split writes its own slab of D.
//   - With MUL_MAT_ID, gl_GlobalInvocationID.z selects an expert; the B rows /
//     output columns for that expert are gathered from data_ids into row_ids.
//
// The following names are not declared in this file and come from the included
// .glsl files:
//   block_a_to_shmem, block_b_to_shmem, block_a_to_registers,
//   block_b_to_registers, mmq_dot_product, load_row_ids, init_iq_shmem,
//   row_ids, _ne1, block_a_cache, block_b_cache, QUANT_R_MMQ, ACC_TYPE,
//   A_TYPE, D_TYPE.
// Comments below that describe their behavior are hedged accordingly.
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_integer_dot_product : require
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
// Subgroup extensions are only enabled for the subgroup-based row_ids gather
// (load_row_ids) used by the MUL_MAT_ID variant.
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
// MUL_MAT_ID stores gathered (ii0, ii1) pairs as u16vec2, which needs 16-bit
// integer arithmetic types.
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
#include "types.glsl"

// Workgroup size along x is a specialization constant (constant_id 0 =
// BLOCK_SIZE); y and z are 1, so gl_GlobalInvocationID.z == gl_WorkGroupID.z.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// A is bound once at binding 0 and optionally aliased with 16-bit and 32-bit
// packed views of the same buffer.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// data_ids: expert id table, read as data_ids[ii1 * nbi1 + ii0].
// data_expert_count: per-expert count, compared against the column-tile start
// (ic * BN) to skip tiles with no work.
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif

layout (push_constant) uniform parameter
{
    // M: output extent along the contiguous dimension of D (rows of A).
    // N: output extent along stride_d (rows of B).
    // K: reduction length, in elements.
    uint M;
    uint N;
    uint K;
    // stride_a / stride_b: distance between consecutive rows of A / B, in the
    // same element units as K (divided by BK to obtain block offsets).
    // stride_d: distance between consecutive output columns in D.
    uint stride_a;
    uint stride_b;
    uint stride_d;

    // Distance between consecutive batch slices. Under MUL_MAT_ID, A is
    // indexed by expert with batch_stride_a, and B/D are indexed by row_ids.y.
    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    // nei0 / nei1: extents of the data_ids scan (ii0 < nei0, ii1 < nei1).
    // nbi1: data_ids stride between successive ii1 values.
    // ne11: a gathered B row uses row index (ii0 % ne11).
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    // k_split: number of K elements handled by one K-split workgroup.
    // ne02: dim-2 extent of A batches (batch_idx_a = i03 * ne02 + i02).
    // ne12: dim-2 extent of B batches (batch_idx = i13 * ne12 + i12).
    // broadcast2 / broadcast3: number of consecutive B batch indices in dim 2 / 3
    // that map onto the same A batch index.
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
} p;

// Tile shape (specialization constants):
//   - a workgroup of BLOCK_SIZE invocations computes a BM x BN output tile;
//   - the tile is divided into WM x WN warp tiles, one per group of WARP
//     invocations;
//   - inside a warp tile, each invocation owns WMITER x WNITER sub-tiles of
//     TM x TN outputs (see the sums[] array and the dr_warp/dc_warp store math).
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

// K block size. It is a compile-time #define rather than a specialization
// constant; constant_id 3 above is left commented out.
#define BK 32

#include "mul_mmq_shmem_types.glsl"

// BK_STEP: number of BK-sized K blocks staged in shared memory per main-loop
// iteration. It is forced to 1 for MUL_MAT_ID; otherwise it can be overridden
// at compile time and defaults to 4.
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif

// Shared memory cache
// Slot layout: buf_a[k_step * BM + row] and buf_b[k_step * BN + row].
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
// These are non-shared globals, so each invocation has its own copy.
// cache_a holds this invocation's WMITER * TM A rows for the current k_step.
// cache_b holds one B row at a time.
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;

// BK / LOAD_VEC_A (resp. BK / LOAD_VEC_B) invocations cooperate on one BK
// block of an A (resp. B) row when filling shared memory. Presumably each
// invocation handles LOAD_VEC_* values; confirm in block_*_to_shmem.
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16

// Not referenced in this file; presumably used by an included helper.
#define NUM_WARPS (BLOCK_SIZE / WARP)

#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"

// Entry point: computes one BM x BN tile of D for one batch (or expert) and one
// K split, then writes it to data_d.
void main() {
    // Column-tile index: output columns [ic * BN, (ic + 1) * BN).
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
    // Skip column tiles that start at or past this expert's count.
    // ic and expert_idx are the same for every invocation in the workgroup
    // (local_size_z = 1), so the whole workgroup returns together and no
    // later barrier() is left with missing participants.
    if (ic * BN >= data_expert_count[expert_idx]) {
        return;
    }
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifndef MUL_MAT_ID
    // Decompose the B batch index into (i13, i12), then map it to the A batch
    // index by dividing by the broadcast factors.
    const uint batch_idx = gl_GlobalInvocationID.z;

    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    // gl_WorkGroupID.x packs both the row-tile index (ir) and the K-split index (ik).
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;

    // Per-warp sub-tiling: WNITER is derived so that one warp's WARP invocations
    // together cover the full WM x WN warp tile.
    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;

    // Position of this invocation:
    //   - warp_i: warp index within the workgroup;
    //   - tiw: lane within the warp;
    //   - tiwr / tiwc: row / column position of the lane inside a WSUBM x WSUBN sub-tile;
    //   - warp_r / warp_c: row / column position of the warp inside the BM x BN tile.
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    const uint tiw = gl_LocalInvocationID.x % WARP;
    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);

    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // Cooperative-load mapping:
    //   - loadr_*: position within a BK block, passed to the helper as iqs;
    //   - loadc_*: first tile row this invocation loads;
    //   - loadstride_*: number of tile rows covered per pass by the whole workgroup.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);

    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
    // Gather the (ii0, ii1) pairs routed to this expert that fall in this column
    // tile into row_ids[0 .. BN), relative to ic * BN. _ne1 ends up as the
    // running match count, capped at (ic + 1) * BN.
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    // bitCount(p.nei0) == 1 tests whether nei0 is a power of two. The flag
    // presumably lets load_row_ids pick a cheaper index decomposition; confirm
    // in mul_mm_id_funcs.glsl.
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Serial scan. Every invocation runs the same loop and writes the same
    // values to row_ids, then the barrier publishes row_ids to the whole workgroup.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                // Only matches belonging to this column tile are stored.
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work
    // _ne1 has the same value in every invocation, so this return is workgroup-uniform.
    if (ic * BN >= _ne1) return;
#endif

    // K range handled by this workgroup. Only the non-ID variant splits K;
    // under MUL_MAT_ID the full K range is used and ik is unused.
#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Base block indices (units of BK elements) of this tile's first K block
    // in A and B.
    uint pos_a_ib =
#ifdef MUL_MAT_ID
        expert_idx * (p.batch_stride_a / BK) +
#else
        batch_idx_a * (p.batch_stride_a / BK) +
#endif
        (ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
    // B rows are resolved per row from row_ids, so only the K advance is tracked here.
    uint pos_b_ib = 0;
#else
    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

    // Per-invocation accumulators. Index layout:
    //   sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]
    ACC_TYPE sums[WMITER * TM * WNITER * TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }

    // Main loop: each iteration stages BK_STEP K blocks of A and B in shared
    // memory, then accumulates them into sums.
    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
        // Stage A tile rows.
        // K steps past end_k are skipped, leaving those buf_a slots with stale
        // contents.
        // NOTE(review): rows ir * BM + buf_ib >= p.M are loaded without a bounds
        // check here. Their results are discarded by the p.M check at store
        // time; confirm the reads themselves stay inside the A buffer.
        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
            const uint buf_ib = loadc_a + l;
            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
            const uint iqs = loadr_a;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                if (block + k_step * BK < end_k) {
                    block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
                }
            }
        }
        // Stage B tile rows.
        // The last argument of block_b_to_shmem says whether the K step lies
        // before end_k. Its out-of-range handling is defined in the included
        // helper (presumably a zero fill, which would neutralize the stale A
        // slots above; confirm).
        // NOTE(review): without MUL_MAT_ID, rows ic * BN + buf_ib >= p.N are
        // not checked here. With MUL_MAT_ID, row_ids slots past this tile's
        // match count are read even though this dispatch did not write them.
        // Confirm both cases stay in bounds.
        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
            const uint buf_ib = loadc_b + l;

#ifdef MUL_MAT_ID
            const u16vec2 row_idx = row_ids[buf_ib];
            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
            const uint iqs = loadr_b;

            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs, block + k_step * BK < end_k);
            }
        }

        // Shared-memory tiles must be complete before any invocation reads them.
        barrier();

        pos_a_ib += BK_STEP;
        pos_b_ib += BK_STEP;

        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
            // Load from shared into cache
            // These are the WMITER * TM A rows owned by this invocation.
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint reg_ib = wsir * TM + cr;
                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;

                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                }
            }

            // For each owned B row, load it into cache_b and accumulate
            // mmq_dot_product() against every cached A row.
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
                    block_b_to_registers(ib);

                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint cache_a_idx = wsir * TM + cr;
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;

                            sums[sums_idx] += mmq_dot_product(cache_a_idx);
                        }
                    }
                }
            }
        }

        // Every invocation must finish reading buf_a/buf_b before the next
        // iteration overwrites them.
        barrier();
    }

    // Store phase: dr / dc are the first output row / column of this warp's tile.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Each K split (ik) writes into its own slab placed after all
    // gl_NumWorkGroups.z batches. The slabs are presumably reduced by a later
    // pass; confirm on the host side.
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                // Columns past the gathered match count have no destination.
                // This break can diverge within the workgroup, which is safe
                // because no barrier follows.
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
                    if (dr_warp + cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#else
                    // Bounds check for partial edge tiles in both dimensions.
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
}