From b333b06772c89d96aacb5490d6a219fba7c09cc6 Mon Sep 17 00:00:00 2001 From: Mitja Felicijan Date: Thu, 12 Feb 2026 20:57:17 +0100 Subject: Engage! --- .../vulkan-shaders/soft_max_large2.comp | 79 ++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp') diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp new file mode 100644 index 0000000..69524f5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp @@ -0,0 +1,79 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + sum += val; + data_d[i] = D_TYPE(val); + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + sum = vals[0]; + data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum; + } +} -- cgit v1.2.3