summary refs log tree commit diff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp')
-rw-r--r-- llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp 121
1 files changed, 121 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
new file mode 100644
index 0000000..68917fc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -0,0 +1,121 @@
+#version 450
+// Split-k reduction pass for flash attention: merges the k_num partial results stored per row in data_a into one output row in data_d.
+#extension GL_EXT_control_flow_attributes : enable
+// The extension enabled above supplies the [[unroll]] loop attribute used in main().
+layout(constant_id = 0) const uint BLOCK_SIZE = 32; // specialization constant 0 (default 32); also sizes tmpsh below
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; // workgroup x size comes from specialization constant 0, i.e. BLOCK_SIZE
+
+layout (binding = 0) readonly buffer A {float data_a[];}; // partial results: O values first, then the l and m values (offsets are computed in main)
+layout (binding = 1) readonly buffer B {float data_s[];}; // sink values; only read when p.sinks != 0
+layout (binding = 2) writeonly buffer D {float data_d[];}; // final reduced output
+
+layout (push_constant) uniform parameter {
+ uint D; // number of output elements per row
+ uint ne1; // row count used in the offset strides (rows are indexed by gl_WorkGroupID.x)
+ uint ne2; // extent of i2; gl_WorkGroupID.z is decoded as i2 + ne2 * i3
+ uint ne3; // only used to size the O region of data_a; presumably the extent of i3 - TODO confirm against the host code
+ uint k_num; // number of partial (split-k) results per row
+ uint sinks; // nonzero: fold the per-row sink value from data_s into the normalization
+} p;
+
+shared float tmpsh[BLOCK_SIZE]; // workgroup scratch for the max and sum reductions in main()
+
+void main() {
+ // Each workgroup handles one row n of slice (i2, i3); gl_WorkGroupID.y selects which BLOCK_SIZE-wide chunk of the row's D outputs it writes
+ const uint n = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+ const uint i2 = gl_WorkGroupID.z % p.ne2; // gl_WorkGroupID.z packs the slice as i2 + ne2 * i3
+ const uint i3 = gl_WorkGroupID.z / p.ne2;
+
+ uint D = p.D;
+ uint k_num = p.k_num;
+ // data_a as indexed below: D*ne1*ne2*ne3*k_num O values come first; then, per (i3, i2, split), ne1 l values followed by ne1 m values
+ uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
+ uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
+ uint lm_stride = p.ne1 * 2; // step from one split's l (or m) entry to the next split's
+
+ // Compute the max m value for the row; invocation tid scans splits tid, tid + BLOCK_SIZE, ...
+ float m_max = -1.0/0.0; // intended as -infinity, the identity element for max
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+ float m = data_a[m_offset + (k + tid) * lm_stride];
+ m_max = max(m_max, m);
+ }
+
+ // reduce across the workgroup: halving-stride reduction in shared memory that leaves the row maximum in tmpsh[0]
+ tmpsh[tid] = m_max;
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ m_max = max(m_max, tmpsh[tid + s]);
+ tmpsh[tid] = m_max;
+ }
+ barrier();
+ }
+ m_max = tmpsh[0];
+ // NOTE(review): the halving loop only visits every element when BLOCK_SIZE is a power of two - confirm the host never specializes it otherwise
+ barrier(); // every invocation must have read tmpsh[0] before tmpsh is overwritten by the sum reduction below
+
+ // Compute L based on m_max: sum over splits of exp(m - m_max) * l, i.e. each partial l rescaled to the common maximum
+ float L = 0;
+ for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+ float l = data_a[l_offset + (k + tid) * lm_stride];
+ float m = data_a[m_offset + (k + tid) * lm_stride];
+ L += exp(m - m_max) * l;
+ }
+
+ // reduce across the workgroup: same halving-stride reduction as above, summing instead of taking the max
+ tmpsh[tid] = L;
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ L += tmpsh[tid + s];
+ tmpsh[tid] = L;
+ }
+ barrier();
+ }
+ L = tmpsh[0];
+
+ float sink; // assigned and read only when p.sinks != 0
+ if (p.sinks != 0) {
+ sink = data_s[n];
+ // Fold the sink into the sum as one extra exp term, rescaling to the larger of sink and m_max so both exponents below are <= 0
+ float ms = 1.0f; // scale applied to the accumulated sum
+ float vs = 1.0f; // contribution of the sink term itself
+
+ if (sink > m_max) {
+ ms = exp(m_max - sink);
+ } else {
+ vs = exp(sink - m_max);
+ }
+
+ L = L*ms + vs;
+ }
+
+ L = (L == 0.0) ? 0.0 : 1.0 / L; // L now holds the reciprocal of the sum; a zero sum maps to 0 instead of dividing by zero
+
+ // D dimension is split across workgroups in the y dimension
+ uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
+ // Scale and sum the O contributions based on m_max and store the result to memory
+ if (d < D) {
+ float O = 0.0;
+ [[unroll]] for (uint k = 0; k < k_num; ++k) {
+ uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; // = ((k + k_num * (i2 + ne2 * i3)) * ne1 + n) * D + d
+ float m = data_a[m_offset + k * lm_stride];
+ O += exp(m - m_max) * data_a[o_offset];
+ }
+ if (p.sinks != 0) {
+ if (sink > m_max) { // same rescale that was applied to L in this case, so O and L share one reference maximum
+ float ms = 1.0f;
+ ms = exp(m_max - sink);
+ O *= ms;
+ }
+ }
+ O *= L; // normalize: L is 1/sum (or 0) at this point
+
+ const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); // bit pattern of the largest finite 32-bit float
+ O = clamp(O, -FLT_MAX, FLT_MAX); // clamp the stored value into the finite float range
+
+ data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; // output index = ((i3 * ne2 + i2) * ne1 + n) * D + d
+ }
+}