Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp')
-rw-r--r--  llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp  309
1 file changed, 309 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
new file mode 100644
index 0000000..335d7f6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -0,0 +1,309 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#extension GL_EXT_integer_dot_product : require
+
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#endif
+
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#endif
+
+#ifdef MUL_MAT_ID
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#endif
+
+#include "types.glsl"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
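+// Binding 0 aliases the quantized A matrix at different packed word sizes;
+// binding 1 holds B pre-quantized to q8_1, four blocks per struct so it can
+// be fetched with 128-bit loads.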
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#ifdef MUL_MAT_ID
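+// Expert routing: data_ids gives the expert index for each (i0, i1) slot and
+// data_expert_count the number of rows routed to each expert.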
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
+#endif
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+} p;
+
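+// Tile hierarchy: each workgroup computes a BM x BN tile of D, each warp a
+// WM x WN subtile, and each thread TM x TN results per WMITER x WNITER step.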
+layout (constant_id = 0) const uint BLOCK_SIZE = 64;
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+// layout (constant_id = 3) const uint BK = 32;
+layout (constant_id = 4) const uint WM = 32;
+layout (constant_id = 5) const uint WN = 32;
+layout (constant_id = 6) const uint WMITER = 2;
+layout (constant_id = 7) const uint TM = 4;
+layout (constant_id = 8) const uint TN = 2;
+layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
+layout (constant_id = 10) const uint WARP = 32;
+
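+// The K tile is fixed at 32 elements, matching the q8_1 quant block size.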
+#define BK 32
+
+#include "mul_mmq_shmem_types.glsl"
+
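+// BK_STEP: number of BK-wide K slabs staged in shared memory per outer loop
+// iteration; the mul_mat_id path stages a single slab at a time.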
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
+#ifndef BK_STEP
+#define BK_STEP 4
+#endif
+#endif
+
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
+
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
+#define LOAD_VEC_B 16
+
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
+#include "mul_mm_id_funcs.glsl"
+#include "mul_mmq_funcs.glsl"
+
+void main() {
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+ if (ic * BN >= data_expert_count[expert_idx]) {
+ return;
+ }
+#endif
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+#ifndef MUL_MAT_ID
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
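+ // gl_WorkGroupID.x encodes both the output row tile (ir) and, when K is
+ // split across workgroups, the K-slice index (ik).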
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+
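+ // Locate this invocation: the warp's position within the workgroup tile,
+ // then the thread's TM x TN tile within the warp's WM x WN subtile.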
+ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+ const uint WSUBM = WM / WMITER;
+ const uint WSUBN = WN / WNITER;
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
+
+ const uint tiw = gl_LocalInvocationID.x % WARP;
+
+ const uint tiwr = tiw % (WSUBM / TM);
+ const uint tiwc = tiw / (WSUBM / TM);
+
+ const uint warp_r = warp_i % (BM / WM);
+ const uint warp_c = warp_i / (BM / WM);
+
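+ // Per-thread coordinates and strides for the cooperative loads of the A
+ // and B tiles into shared memory.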
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+
+ const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
+ const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
+
+#ifdef MUL_MAT_ID
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true, ic);
+ } else {
+ load_row_ids(expert_idx, false, ic);
+ }
+#else
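+ // Scalar fallback: scan the id matrix and gather the coordinates of rows
+ // routed to this expert, keeping those that fall into this BN tile.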
+ _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
+ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
+ _ne1++;
+ }
+ }
+ }
+
+ barrier();
+#endif
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
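+ // K range assigned to this workgroup.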
+#ifdef MUL_MAT_ID
+ const uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ const uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
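+ // Start offsets into A and B, measured in BK-element quant blocks.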
+ uint pos_a_ib =
+#ifdef MUL_MAT_ID
+ expert_idx * (p.batch_stride_a / BK) +
+#else
+ batch_idx_a * (p.batch_stride_a / BK) +
+#endif
+ (ir * BM * p.stride_a + start_k) / BK;
+#ifdef MUL_MAT_ID
+ uint pos_b_ib = 0;
+#else
+ uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
+#endif
+
+ ACC_TYPE sums[WMITER * TM * WNITER * TN];
+
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+ sums[i] = ACC_TYPE(0.0f);
+ }
+
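+ // Main loop over K: cooperatively stage BK_STEP slabs of the A and B tiles
+ // in shared memory, then accumulate integer dot products over them.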
+ for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
+ [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
+ const uint buf_ib = loadc_a + l;
+ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+ const uint iqs = loadr_a;
+
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ if (block + k_step * BK < end_k) {
+ block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
+ }
+ }
+ }
+ [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+ const uint buf_ib = loadc_b + l;
+
+#ifdef MUL_MAT_ID
+ const u16vec2 row_idx = row_ids[buf_ib];
+ const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
+#else
+ const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
+#endif
+ const uint iqs = loadr_b;
+
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs, block + k_step * BK < end_k);
+ }
+ }
+
+ barrier();
+
+ pos_a_ib += BK_STEP;
+ pos_b_ib += BK_STEP;
+
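+ // Multiply the staged slabs, one BK-wide slab at a time.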
+ for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ // Load from shared into cache
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint reg_ib = wsir * TM + cr;
+ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+
+ block_a_to_registers(reg_ib, k_step * BM + buf_ib);
+ }
+ }
+
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+ const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+ block_b_to_registers(ib);
+
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint cache_a_idx = wsir * TM + cr;
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+
+ sums[sums_idx] += mmq_dot_product(cache_a_idx);
+ }
+ }
+ }
+ }
+ }
+
+ barrier();
+ }
+
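+ // Write back. dr/dc locate this warp's output; with K-splitting, each
+ // slice ik writes its partial sums to a separate slab of D so they can be
+ // reduced afterwards.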
+ const uint dr = ir * BM + warp_r * WM;
+ const uint dc = ic * BN + warp_c * WN;
+
+#ifndef MUL_MAT_ID
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+
+ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
+ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+#ifdef MUL_MAT_ID
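+ // Map the tile-local column back to its source row via row_ids.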
+ const uint row_i = dc_warp + cc;
+ if (row_i >= _ne1) break;
+
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
+#endif // MUL_MAT_ID
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
+#ifdef MUL_MAT_ID
+ if (dr_warp + cr < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+ }
+#else
+ if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+ }
+#endif // MUL_MAT_ID
+ }
+ }
+ }
+ }
+}