summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp83
1 files changed, 83 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
new file mode 100644
index 0000000..75e3c3b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
@@ -0,0 +1,83 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; // workgroup is 1-D: x size comes from specialization constant 0 (the same id as BLOCK_SIZE below)
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // input elements; only read by this shader (A_TYPE is defined outside this file)
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // output elements; only written by this shader (D_TYPE is defined outside this file)
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128; // shares constant id 0 with local_size_x_id, so it equals the workgroup's x size
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; // used to map an invocation to a subgroup slot; presumably set by the host to the device's real subgroup size - TODO confirm
+layout (constant_id = 2) const uint ELEM_PER_THREAD = 4; // number of consecutive columns each invocation handles per loop iteration
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; // one slot per subgroup of the workgroup: that subgroup's total for the current iteration
+shared FLOAT_TYPE last_sum; // running total carried from one loop iteration of main() to the next
+
+void main() { // Writes the inclusive running sum of one input row into the matching output row; the row is chosen by the workgroup id alone, so each workgroup handles one row.
+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; // flatten the 3-D workgroup id into a row number (262144 = 512 * 512)
+ const uint tid = gl_LocalInvocationID.x; // this invocation's index inside the workgroup
+
+ const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); // outermost index; fastdiv is defined elsewhere, presumably row / (ne01 * ne02) - TODO confirm
+ const uint i03_offset = i03 * p.ne01*p.ne02; // number of rows that belong to lower i03 values
+ const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); // middle index; presumably (remaining rows) / ne01 - TODO confirm
+ const uint i01 = row - i03_offset - i02*p.ne01; // what is left over is the innermost row index
+
+ const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; // index into data_a of the row's first element (strides are used as element counts here)
+ const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; // index into data_d of the row's first element
+
+ uint subgroup_id = tid / SUBGROUP_SIZE; // slot in partial[] for this invocation; assumes subgroups are filled in tid order, SUBGROUP_SIZE at a time - TODO confirm
+
+ if (tid == 0) {
+ last_sum = 0; // no carry before the first iteration; the first barrier() in the loop comes before anyone reads this
+ }
+
+ uint col = tid * ELEM_PER_THREAD; // first column this invocation handles in the first iteration
+ uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD); // passes needed when the workgroup covers BLOCK_SIZE * ELEM_PER_THREAD columns per pass
+ for (int i = 0; i < num_iter; ++i) { // trip count depends only on p.n_cols and constants, so every invocation reaches the barrier() calls the same number of times
+ FLOAT_TYPE v[ELEM_PER_THREAD]; // prefix sums for this invocation's columns, refined step by step below
+ FLOAT_TYPE thread_sum = 0; // running sum over this invocation's own columns
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ if (col + j < p.n_cols) { // columns past the end of the row add nothing
+ thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
+ }
+ v[j] = thread_sum; // inclusive prefix sum within this invocation's chunk
+ }
+
+ thread_sum = subgroupExclusiveAdd(thread_sum); // now the sum of the chunk totals of the lower-numbered invocations in this subgroup
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += thread_sum; // v[j] becomes the prefix sum relative to the start of this subgroup's columns
+ }
+ // Store the largest partial sum for each subgroup, then add the partials for all
+ // lower subgroups and the final partial sum from the previous iteration.
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { // highest invocation of the subgroup; assumes the real subgroup size equals SUBGROUP_SIZE - TODO confirm
+ partial[subgroup_id] = v[ELEM_PER_THREAD - 1]; // its last prefix value is the whole subgroup's total for this pass
+ }
+ barrier(); // wait until every subgroup has stored its total (and, on the first pass, tid 0 has cleared last_sum)
+ for (int s = 0; s < subgroup_id; ++s) { // add the totals of all lower-numbered subgroups
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += partial[s];
+ }
+ }
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ v[j] += last_sum; // add the carry accumulated over all previous passes
+ }
+ barrier(); // every invocation must finish reading partial[] and last_sum before either is overwritten
+ if (tid == BLOCK_SIZE - 1) { // the last invocation of the workgroup holds the total up to the end of this pass
+ last_sum = v[ELEM_PER_THREAD - 1]; // this becomes the carry for the next pass
+ }
+ [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+ if (col + j < p.n_cols) { // never write past the end of the row
+ data_d[dst_idx + col + j] = D_TYPE(v[j]);
+ }
+ }
+ col += BLOCK_SIZE * ELEM_PER_THREAD; // move to this invocation's chunk in the next pass
+ }
+}