diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp new file mode 100644 index 0000000..75e3c3b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -0,0 +1,83 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout (constant_id = 2) const uint ELEM_PER_THREAD = 4; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; +shared FLOAT_TYPE last_sum; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + if (tid == 0) { + last_sum = 0; + } + + uint col = tid * ELEM_PER_THREAD; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD); + for (int i = 0; i < num_iter; ++i) { + FLOAT_TYPE v[ELEM_PER_THREAD]; + FLOAT_TYPE thread_sum = 0; + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]); + } + v[j] = thread_sum; + } + + thread_sum = subgroupExclusiveAdd(thread_sum); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += thread_sum; + } + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. + if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v[ELEM_PER_THREAD - 1]; + } + barrier(); + for (int s = 0; s < subgroup_id; ++s) { + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += partial[s]; + } + } + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += last_sum; + } + barrier(); + if (tid == BLOCK_SIZE - 1) { + last_sum = v[ELEM_PER_THREAD - 1]; + } + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + data_d[dst_idx + col + j] = D_TYPE(v[j]); + } + } + col += BLOCK_SIZE * ELEM_PER_THREAD; + } +} |
