summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp62
1 files changed, 62 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
new file mode 100644
index 0000000..39c4663
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
@@ -0,0 +1,62 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+ const uint tid = gl_LocalInvocationID.x;
+ const uint rowx = gl_WorkGroupID.y;
+ const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+ const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+ const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+ const uint32_t i01 = rowx % p.ne01;
+
+ uint rowy_start = 0;
+ if (p.KY > 0) {
+ rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+ }
+
+ if (rowx >= p.nrows_x) {
+ return;
+ }
+
+ float slope = get_slope(rowx);
+
+ // Find max
+ FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+ [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+ const uint col = col0 + tid;
+
+ FLOAT_TYPE a = FLOAT_TYPE(0);
+ if (col < p.KX) {
+ a = data_a[rowx * p.KX + col];
+ }
+
+ FLOAT_TYPE b = FLOAT_TYPE(0);
+ if (p.KY > 0 && col < p.KX) {
+ b = data_b[rowy_start + col];
+ }
+
+ FLOAT_TYPE v = a * p.scale + slope * b;
+
+ if (col < p.KX) {
+ max_val = max(max_val, v);
+ }
+ }
+
+ // reduce across the workgroup
+ vals[tid] = max_val;
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ vals[tid] = max(vals[tid], vals[tid + s]);
+ }
+ barrier();
+ }
+
+ if (tid == 0) {
+ max_val = vals[0];
+ data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
+ }
+}