diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp new file mode 100644 index 0000000..d62696b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -0,0 +1,44 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float src0[]; }; +layout(binding = 1) readonly buffer Src1 { float src1[]; }; +layout(binding = 2) buffer Dst { float dst[]; }; + +layout(push_constant) uniform PushConstants { + uint nb01; uint nb02; + uint nb11; + uint dst_nb0; uint dst_nb1; uint dst_nb2; + uint nc; uint ncs; uint nr; uint n_t; uint n_s; +}; + +void main() { + const uint global_thread_id = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y; + const uint i3 = gl_WorkGroupID.z; + + if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + return; + } + + const uint i1 = global_thread_id; + const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); + const uint src1_base = i1 * (nb11 / 4); + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; + + float sum = 0.0; + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + const uint src0_idx = src0_base + i0; + const uint src1_idx = src1_base + i0; + sum += src0[src0_idx] * src1[src1_idx]; + } + + dst[dst_idx] = sum; +} |
