summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp27
1 files changed, 27 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
new file mode 100644
index 0000000..02ef1ea
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
@@ -0,0 +1,27 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_binary_head.glsl"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ uint idx = get_idx();
+
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
+
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03);
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+ idx += num_threads;
+ }
+}