summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp143
1 files changed, 143 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
new file mode 100644
index 0000000..6fe3e2d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
@@ -0,0 +1,143 @@
+#version 450
+
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_integer_dot_product : require
+
+#define MMQ
+#define B_TYPE block_q8_1_x4
+
+#include "mul_mat_vec_base.glsl"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
+#define K_PER_ITER 8
+#elif defined(DATA_A_QUANT_K)
+#define K_PER_ITER 16
+#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
+#define K_PER_ITER 32
+#else
+#error unimplemented
+#endif
+
+uint a_offset, b_offset, d_offset;
+
+int32_t cache_b_qs[K_PER_ITER / 4];
+vec2 cache_b_ds;
+
+#include "mul_mat_vecq_funcs.glsl"
+
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
+
+ // Preload data_b block
+ const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
+ const uint b_qs_idx = tid % (32 / K_PER_ITER);
+ const uint b_block_idx_outer = b_block_idx / 4;
+ const uint b_block_idx_inner = b_block_idx % 4;
+ cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
+
+#if QUANT_R == 2
+ // Assumes K_PER_ITER == 8
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
+#else
+#if K_PER_ITER == 8
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
+#elif K_PER_ITER == 16
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
+ cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
+ cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
+#elif K_PER_ITER == 32
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
+ cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
+ cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
+ cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
+ cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
+ cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
+ cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
+#else
+#error unimplemented
+#endif
+#endif
+
+ uint ibi = first_row*p.ncols;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
+ ibi += p.ncols;
+
+ temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
+ }
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ const uint tid = gl_LocalInvocationID.x;
+
+ get_offsets(a_offset, b_offset, d_offset);
+ a_offset *= QUANT_K / QUANT_K_Q8_1;
+ b_offset /= QUANT_K_Q8_1;
+
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ temp[j][n] = FLOAT_TYPE(0.0f);
+ }
+ }
+
+ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+ num_iters++;
+ }
+ int unroll_count = 4;
+ uint unrolled_iters = num_iters & ~(unroll_count - 1);
+
+ uint i = 0;
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+ }
+
+ unroll_count = 2;
+ unrolled_iters = num_iters & ~(unroll_count - 1);
+
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+ }
+ while (i < num_iters) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+#ifdef NEEDS_INIT_IQ_SHMEM
+ init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}