diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp | 156 |
1 files changed, 156 insertions, 0 deletions
#version 450

// mul_mat_vec_p021.comp
//
// Matrix-vector multiply shader. Each workgroup handles one row of A
// (gl_GlobalInvocationID.y) for one channel group (gl_GlobalInvocationID.z):
// its invocations accumulate partial dot products over the ncols_x columns,
// the partial sums are reduced across invocations, and invocation 0 writes
// the result(s) to data_d, optionally adding the fused bias buffers.

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#define FLOAT_TYPE float

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Presumably declares the data_a / data_b / data_d / data_fuse* buffers (and
// their *_v4 views) plus the MAT_VEC_FUSION_FLAGS_* constants used below.
#include "mul_mat_vec_iface.glsl"

// Workgroup width along x (same specialization constant id as local_size_x).
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;

// NOTE(review): b_offset and d_offset are not referenced inside main();
// confirm whether they are intentionally unused by this shader.
layout (push_constant) uniform parameter
{
    uint ncols_x;
    uint nrows_x;
    uint nchannels_x;
    uint nchannels_y;
    uint b_offset;
    uint d_offset;
    uint fusion_flags;
} p;

#if !USE_SUBGROUP_ADD
// Scratch space for the shared-memory reduction: one row of partial sums per
// output channel handled by this workgroup (at most 8, see gqa_ratio).
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif

void main() {
    const uint tid   = gl_LocalInvocationID.x;
    const uint row_a = gl_GlobalInvocationID.y;
    const uint lanes = uint(BLOCK_SIZE);

    // When gqa_ratio > 1, each invocation does multiple rows: the row in the
    // A matrix comes from channel_a and the rows in the B matrix are
    // [channel_b, channel_b + gqa_ratio).
    // When gqa_ratio is 1, each invocation does one row and the A channel is
    // derived from the B channel.
    uint channel_b;
    uint channel_a;
    if (gqa_ratio > 1) {
        channel_a = gl_GlobalInvocationID.z;
        channel_b = channel_a * gqa_ratio;
    } else {
        channel_b = gl_GlobalInvocationID.z;
        channel_a = channel_b / (p.nchannels_y / p.nchannels_x);
    }

    // One running sum per output channel; only the first gqa_ratio are used.
    FLOAT_TYPE acc[8];
    [[unroll]] for (uint k = 0; k < 8; ++k) {
        acc[k] = FLOAT_TYPE(0.0f);
    }

    // vec4 loads are only taken when the relevant extents are multiples of 4.
    // (The B row length equals ncols_x, so it needs no separate check.)
    const bool can_vectorize = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0;

    // x is transposed and permuted: index of column 0 for this row/channel.
    const uint a_base = row_a*p.nchannels_x*p.ncols_x + channel_a*p.ncols_x;

    uint base = 0;
    while (base < p.ncols_x) {
        if (can_vectorize && base + 4*BLOCK_SIZE <= p.ncols_x) {
            // Vector path: every invocation consumes four consecutive columns,
            // so the workgroup covers 4*BLOCK_SIZE columns per iteration.
            const uint col = base + 4*tid;
            const vec4 a4 = vec4(data_a_v4[(a_base + col) / 4]);

            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                // y is not transposed but permuted
                const uint b_idx = (channel_b + c)*p.ncols_x + col;
                const vec4 b4 = data_b_v4[b_idx / 4];
                acc[c] += dot(a4, b4);
            }

            base += 4*BLOCK_SIZE;
        } else {
            // Scalar path: one column per invocation.
            const uint col = base + tid;
            if (col >= p.ncols_x) {
                break;
            }

            const FLOAT_TYPE a_val = FLOAT_TYPE(data_a[a_base + col]);

            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                // y is not transposed but permuted
                const uint b_idx = (channel_b + c)*p.ncols_x + col;
                acc[c] = fma(a_val, FLOAT_TYPE(data_b[b_idx]), acc[c]);
            }

            base += BLOCK_SIZE;
        }
    }

#if USE_SUBGROUP_ADD
    // Reduce four accumulators per subgroupAdd call.
    // NOTE(review): this sums across the subgroup only; presumably the
    // workgroup is a single subgroup when USE_SUBGROUP_ADD is set - confirm.
    const vec4 lo = subgroupAdd(vec4(acc[0], acc[1], acc[2], acc[3]));
    acc[0] = lo[0];
    acc[1] = lo[1];
    acc[2] = lo[2];
    acc[3] = lo[3];
    if (gqa_ratio > 4) {
        const vec4 hi = subgroupAdd(vec4(acc[4], acc[5], acc[6], acc[7]));
        acc[4] = hi[0];
        acc[5] = hi[1];
        acc[6] = hi[2];
        acc[7] = hi[3];
    }
#else
    // Publish every invocation's partial sums, then halve the number of
    // active invocations each step until element 0 holds the total.
    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
        tmp[c][tid] = acc[c];
    }
    barrier();
    [[unroll]] for (uint stride = lanes / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
                acc[c] += tmp[c][tid + stride];
                tmp[c][tid] = acc[c];
            }
        }
        // Executed by all invocations, outside the tid check.
        barrier();
    }
    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
        acc[c] = tmp[c][tid];
    }
#endif

    // Only the first invocation writes results.
    if (tid != 0) {
        return;
    }

    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
        // dst is not transposed and not permuted
        const uint d_idx = (channel_b + c)*p.nrows_x + row_a;
        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
            acc[c] += FLOAT_TYPE(data_fuse0[d_idx]);
        }
        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
            acc[c] += FLOAT_TYPE(data_fuse1[d_idx]);
        }
        data_d[d_idx] = acc[c];
    }
}
