summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp69
1 files changed, 69 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
new file mode 100644
index 0000000..3bcfe69
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
@@ -0,0 +1,69 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#if ADD_RMS
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
+
+#include "types.glsl"
+#include "generic_binary_head.glsl"
+
+const uint num_threads = 256; // workgroup size in x; also the stride between elements handled by one invocation in main()
+// Output buffer for the sum of squares; main() writes one float per orig_idx / (num_iter * num_threads) when ADD_RMS is set and p.param3 != 0.
+layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads]; // per-subgroup partial sums, indexed by gl_SubgroupID in main(); sized for the worst case of one slot per invocation
+#endif
+
+void main() { // Element-wise add of data_a and data_b into data_d; with ADD_RMS also reduces the sum of squared results into partial_sums.
+ uint idx = get_idx(); // first flat element index for this invocation (get_idx is defined outside this file, presumably in generic_binary_head.glsl)
+ uint orig_idx = idx; // kept because idx is advanced in the loop; used below to pick the partial_sums slot
+
+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+ const uint num_iter = 2;
+
+ FLOAT_TYPE sum_sq = 0; // running sum of squares of the results this invocation produced
+ // Each invocation handles up to num_iter elements spaced num_threads apart.
+ [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+ if (idx >= p.ne) {
+ continue; // past the last element; idx is not advanced here, so every remaining iteration is skipped too
+ }
+ uint i00, i01, i02, i03;
+ get_indices(idx, i00, i01, i02, i03); // fills the four per-dimension indices for flat index idx
+ // Both operands are converted to FLOAT_TYPE before the add.
+ FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
+ sum_sq += sum*sum;
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); // result is converted to the destination type on store
+
+ idx += num_threads;
+ }
+
+#if ADD_RMS
+ if (p.param3 != 0) { // NOTE(review): param3 appears to act as the enable flag for the partial-sum output -- confirm against the host code
+ // reduce the sum within each subgroup, then across subgroups
+ const uint NumSubgroups = num_threads / gl_SubgroupSize;
+ sum_sq = subgroupAdd(sum_sq); // every invocation now holds its subgroup's total
+ if (gl_SubgroupInvocationID == 0) {
+ sumsh[gl_SubgroupID] = sum_sq; // one invocation per subgroup publishes that total to shared memory
+ }
+ barrier(); // all subgroup totals must be written before any of them is read below
+ [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { // halving tree over the subgroup totals; assumes NumSubgroups is a power of two -- TODO confirm
+ if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+ sum_sq += sumsh[gl_SubgroupID + s];
+ sumsh[gl_SubgroupID] = sum_sq;
+ }
+ barrier(); // finish this level before the next level reads sumsh
+ }
+ // After the loop, invocation 0 of subgroup 0 holds the total for the whole workgroup.
+ if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+ partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; // slot index is orig_idx divided by the element count covered per workgroup
+ }
+ }
+#endif
+}