Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp')
-rw-r--r--  llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp  169
1 files changed, 169 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
new file mode 100644
index 0000000..2271be4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -0,0 +1,169 @@
+#version 450
+
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#include "mul_mat_vec_base.glsl"
+#include "dequant_funcs.glsl"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
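+// The workgroup width (local_size_x) is supplied by the host through specialization constant 0;
+// BLOCK_SIZE used below is this per-workgroup thread count.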
+
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
+#define K_PER_ITER 8
+#else
+#define K_PER_ITER 2
+#endif
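+// K_PER_ITER is the number of values along the K dimension each thread consumes per call to
+// iter(): 8 for quantized A types (two vec4 dot products), 2 for F32/F16/BF16.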
+
+
+uint a_offset, b_offset, d_offset, y_offset;
+
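+// Accumulate one K slice into temp: thread tid handles K_PER_ITER consecutive positions along
+// the K dimension for every B column (NUM_COLS) and every row in [first_row, first_row + num_rows).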
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
+{
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
+        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+        const uint iybs = col - col%QUANT_K; // y block start index
+
+#if K_PER_ITER == 8
+#if QUANT_R == 2
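+        // For QUANT_R == 2 the two values packed in one quant byte map to B elements that
+        // are y_offset (QUANT_K/2) apart, so load both halves of the block and interleave
+        // them to match the output order of dequantize4 below.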
+        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
+        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
+        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
+#else
+        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
+#endif
+#else
+        // Check if the second of the pair of elements is OOB, and don't fetch B or
+        // accumulate it. We still fetch a pair of elements for A, which is fine for
+        // quantized formats since they'll be within the same block. We should
+        // probably skip fetching the second element for F16/F32, but as of now we
+        // still do.
+        const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
+
+        FLOAT_TYPE b0 = 0, b1 = 0;
+        b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
+        if (!OOB) {
+            b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
+        }
+#endif
+        uint ibi = first_row*p.ncols;
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib = (ibi + col)/QUANT_K; // block index
+            ibi += p.ncols;
+
+#if K_PER_ITER == 8
+            vec4 v = dequantize4(ib, iqs, a_offset);
+            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+
+            const vec2 dm = get_dm(ib, a_offset);
+            if (dm.y != 0) { // quant has min component
+                v = v * dm.x + dm.y;
+                v2 = v2 * dm.x + dm.y;
+            }
+
+            // matrix multiplication
+            FLOAT_TYPE rowtmp = dot(bv0, v);
+            rowtmp += dot(bv1, v2);
+
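+            // Scale-only formats (dm.y == 0): apply the scale once to the accumulated dot
+            // product instead of to every dequantized element.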
+            if (dm.y == 0)
+                rowtmp *= dm.x;
+
+            temp[j][n] += rowtmp;
+#else
+            const vec2 v = dequantize(ib, iqs, a_offset);
+
+            // matrix multiplication
+            temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
+            if (!OOB) {
+                temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
+            }
+#endif
+        }
+    }
+}
+
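+// Each invocation accumulates partial dot products for its slice of the K dimension, then
+// reduce_result (from the base include) combines the partial sums across the workgroup and
+// writes rows [first_row, first_row + num_rows) of the output.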
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+    const uint tid = gl_LocalInvocationID.x;
+
+    get_offsets(a_offset, b_offset, d_offset);
+
+    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+
+    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
+    }
+
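+    // Each thread covers K_PER_ITER values of the K dimension per iteration, so a full row
+    // needs ncols / (K_PER_ITER * BLOCK_SIZE) iterations, rounded up for threads whose
+    // starting column still falls inside the row.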
+    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+        num_iters++;
+    }
+    int unroll_count = 4;
+    uint unrolled_iters = num_iters & ~(unroll_count - 1);
+
+#if K_PER_ITER == 2
+    // If the K dimension is odd, we need lastiter==true on the last iteration
+    // so OOB is computed correctly. Skip some unrolling to make that happen.
+    if ((p.ncols & 1) != 0 &&
+        unrolled_iters == num_iters &&
+        unrolled_iters > 0) {
+        unrolled_iters -= unroll_count;
+    }
+#endif
+
+    uint i = 0;
+    while (i < unrolled_iters) {
+        // Manually partially unroll the loop
+        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
+        }
+    }
+
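+    // Mop up iterations left over from the 4x-unrolled loop with a 2x-unrolled pass.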
+    unroll_count = 2;
+    unrolled_iters = num_iters & ~(unroll_count - 1);
+
+#if K_PER_ITER == 2
+    if ((p.ncols & 1) != 0 &&
+        unrolled_iters == num_iters &&
+        unrolled_iters > 0) {
+        unrolled_iters -= unroll_count;
+    }
+#endif
+
+    while (i < unrolled_iters) {
+        // Manually partially unroll the loop
+        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
+        }
+    }
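+    // Any remaining iterations run one at a time with lastiter=true so the OOB check applies.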
+    while (i < num_iters) {
+        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
+        i++;
+    }
+
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
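+    // The x and z workgroup indices together select which block of NUM_ROWS output rows this
+    // workgroup computes.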
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
+    }
+}