summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp156
1 files changed, 156 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp
new file mode 100644
index 0000000..32628c6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp
@@ -0,0 +1,156 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
+
+#define FLOAT_TYPE float
+
+// Workgroup width is specialized at pipeline-creation time via constant id 0;
+// y and z stay fixed at 1.
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#include "mul_mat_vec_iface.glsl"
+
+// Shares constant id 0 with local_size_x above, so BLOCK_SIZE always equals
+// the workgroup width (32 is only the pre-specialization default).
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
+
+// Push constants: matrix extents plus buffer offsets and fusion control.
+layout (push_constant) uniform parameter
+{
+    uint ncols_x;       // columns of x; also used as the row count of y
+    uint nrows_x;       // rows of x; also used as the row count of dst
+    uint nchannels_x;   // channels of x
+    uint nchannels_y;   // channels of y; presumably a multiple of nchannels_x (it is divided by it below) — verify against caller
+    uint b_offset;
+    uint d_offset;
+    uint fusion_flags;  // MAT_VEC_FUSION_FLAGS_* bits enabling fused bias adds
+} p;
+
+#if !USE_SUBGROUP_ADD
+// Scratch for the shared-memory tree reduction: one row of partial sums per
+// gqa channel (up to 8), one column per invocation.
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
+
+// Computes, for c in [0, gqa_ratio):
+//   dst[channel + c][row_dst] = dot(x[channel_x][row_x][:], y[channel + c][:])
+// where x is stored transposed/permuted (p021 layout) and y is permuted but
+// not transposed. One workgroup reduces one dot product per gqa channel.
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint row_x = gl_GlobalInvocationID.y;
+
+    uint channel, channel_x;
+
+    // When gqa_ratio > 1, each invocation does multiple rows.
+    // The row in the A matrix is starting from channel / gqa_ratio and the
+    // rows in the B matrix are [channel, channel+gqa_ratio).
+    // When gqa_ratio is 1, each invocation does one row.
+    if (gqa_ratio > 1) {
+        channel_x = gl_GlobalInvocationID.z;
+        channel = channel_x * gqa_ratio;
+    } else {
+        channel = gl_GlobalInvocationID.z;
+        // Broadcast case: several consecutive y channels map to one x channel.
+        channel_x = channel / (p.nchannels_y / p.nchannels_x);
+    }
+
+    const uint nrows_y = p.ncols_x;
+    const uint nrows_dst = p.nrows_x;
+    const uint row_dst = row_x;
+
+    // Per-invocation partial sums, one accumulator per gqa channel.
+    FLOAT_TYPE temp[8];
+    [[unroll]] for (uint i = 0; i < 8; ++i) {
+        temp[i] = FLOAT_TYPE(0.0f);
+    }
+
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
+
+    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
+
+        // Use vec4 loads if aligned
+        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+
+            uint col_x = col_x0 + 4*tid;
+            const uint row_y = col_x;
+
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                vec4 bv4 = data_b_v4[iy / 4];
+                temp[c] += dot(av4, bv4);
+            }
+
+            // Each invocation consumed 4 columns this iteration; together with
+            // the loop increment this advances col_x0 by 4*BLOCK_SIZE.
+            col_x0 += 3*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+
+            if (col_x >= p.ncols_x) {
+                break;
+            }
+
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+            const uint row_y = col_x;
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+            }
+        }
+    }
+
+#if USE_SUBGROUP_ADD
+    // NOTE(review): this path appears to assume the subgroup spans the whole
+    // workgroup (subgroup size == BLOCK_SIZE) so subgroupAdd is a full
+    // reduction — confirm the host only selects USE_SUBGROUP_ADD under that
+    // condition.
+    // reduce vec4 at a time
+    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+    t = subgroupAdd(t);
+    temp[0] = t[0];
+    temp[1] = t[1];
+    temp[2] = t[2];
+    temp[3] = t[3];
+    if (gqa_ratio > 4) {
+        t = vec4(temp[4], temp[5], temp[6], temp[7]);
+        t = subgroupAdd(t);
+        temp[4] = t[0];
+        temp[5] = t[1];
+        temp[6] = t[2];
+        temp[7] = t[3];
+    }
+#else
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        tmp[c][tid] = temp[c];
+    }
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                temp[c] += tmp[c][tid + s];
+                tmp[c][tid] = temp[c];
+            }
+        }
+        barrier();
+    }
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        temp[c] = tmp[c][tid];
+    }
+#endif
+
+    // Invocation 0 holds the fully reduced sums; apply optional fused bias
+    // adds and store.
+    if (tid == 0) {
+        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+            // dst is not transposed and not permuted
+            const uint idst = (channel + c)*nrows_dst + row_dst;
+            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                temp[c] += FLOAT_TYPE(data_fuse0[idst]);
+            }
+            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+                temp[c] += FLOAT_TYPE(data_fuse1[idst]);
+            }
+            data_d[idst] = temp[c];
+        }
+    }
+}