summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp24
1 files changed, 24 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
new file mode 100644
index 0000000..35ec726
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
@@ -0,0 +1,24 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+const uint num_threads = 128;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+ uint idx = get_idx();
+
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 4;
+
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue;
+ }
+
+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
+ idx += num_threads;
+ }
+}