From b333b06772c89d96aacb5490d6a219fba7c09cc6 Mon Sep 17 00:00:00 2001 From: Mitja Felicijan Date: Thu, 12 Feb 2026 20:57:17 +0100 Subject: Engage! --- .../vulkan-shaders/rms_norm_partials.comp | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp') diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 0000000..4618b2c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} -- cgit v1.2.3