summary refs log tree commit diff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp')
-rw-r--r-- llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp 44
1 files changed, 44 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
new file mode 100644
index 0000000..d62696b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
@@ -0,0 +1,44 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 32; // specialization constant 0; value is 32 unless overridden at pipeline creation
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; // workgroup x-size is taken from specialization constant 0 (the id BLOCK_SIZE uses); y and z sizes are 1
+
+layout(binding = 0) readonly buffer Src0 { float src0[]; }; // read-only float input; left factor of each product summed in main()
+layout(binding = 1) readonly buffer Src1 { float src1[]; }; // read-only float input; right factor of each product summed in main()
+layout(binding = 2) buffer Dst { float dst[]; }; // float output; main() stores one sum per in-range invocation
+
+layout(push_constant) uniform PushConstants { // per-dispatch parameters; main() divides every nb* value by 4 before indexing (presumably byte strides -> float strides; confirm against host code)
+ uint nb01; uint nb02; // src0 strides: nb01 scales the global x index, nb02 scales the workgroup z index
+ uint nb11; // src1 stride: scales the global x index
+ uint dst_nb0; uint dst_nb1; uint dst_nb2; // dst strides: dst_nb1 scales workgroup y, dst_nb2 scales workgroup z; dst_nb0 is not read by the visible main()
+ uint nc; uint ncs; uint nr; uint n_t; uint n_s; // nc: number of products summed per output; nr/n_t/n_s: exclusive upper bounds on the x/y/z indices; ncs is not read by the visible main()
+};
+
+void main() {
+    // One invocation per dst element: global x, workgroup y and workgroup z select it.
+    const uint idx_r = gl_GlobalInvocationID.x;
+    const uint idx_t = gl_WorkGroupID.y;
+    const uint idx_s = gl_WorkGroupID.z;
+
+    // Only invocations inside the nr x n_t x n_s range do any work.
+    if (!(idx_r < nr && idx_t < n_t && idx_s < n_s)) {
+        return;
+    }
+
+    // nb* values are divided by 4 before indexing the float arrays (presumably byte strides; confirm on the host side).
+    const uint lhs_start = idx_s * (nb02 / 4) + idx_t + idx_r * (nb01 / 4);
+    const uint rhs_start = idx_r * (nb11 / 4);
+    const uint out_pos = idx_s * (dst_nb2 / 4) + idx_t * (dst_nb1 / 4) + idx_r;
+
+    // Sum nc products in ascending order, stepping through src0 and src1 together.
+    float acc = 0.0;
+    [[unroll]] for (uint k = 0; k < nc; ++k) {
+        acc += src0[lhs_start + k] * src1[rhs_start + k];
+    }
+    dst[out_pos] = acc;
+}