summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp142
1 files changed, 142 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
new file mode 100644
index 0000000..8c92c1a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
@@ -0,0 +1,142 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128; // invocations per workgroup; local_size_x below is bound to the same constant_id 0
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4; // length of the minsh/maxsh scratch arrays, which are indexed by gl_SubgroupID
+layout (constant_id = 2) const uint Br = 32; // tile height: rows per tile, along the axis bounded by nem1
+layout (constant_id = 3) const uint Bc = 32; // tile width: columns per tile, along the axis bounded by nem0
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; // 1-D workgroup whose x size comes from specialization constant 0
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];}; // mask input viewed as scalar fp16 values (used by the bounds-checked path)
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];}; // the same binding 0 viewed as groups of four fp16 values (used by the fast path)
+layout (binding = 1) writeonly buffer D {uint data_d[];}; // output: one packed 32-bit word is stored per workgroup
+
+layout (push_constant) uniform parameter {
+ uint nem0; // mask extent along the column axis; main() bounds-checks the column index j0 against it
+ uint nem1; // mask extent along the row axis; main() bounds-checks the row index j1 against it
+ uint nem2; // divisor that splits gl_WorkGroupID.z into i2 (remainder) and i3 (quotient)
+ uint nbm1; // mask stride per row, applied directly as an element offset when indexing data_a
+ uint nbm2; // mask stride per i2 step, in data_a elements
+ uint nbm3; // mask stride per i3 step, in data_a elements
+ uint nbd1; // output stride per i1 step, in data_d (uint) elements
+ uint nbd2; // output stride per i2 step, in data_d elements
+ uint nbd3; // output stride per i3 step, in data_d elements
+};
+
+#define MASK_OPT_ALL_NEG_INF 1 // two-bit code for a tile whose maximum is <= -FLT_MAX/2, i.e. every value is -inf
+#define MASK_OPT_ALL_ZERO 2 // two-bit code for a tile whose minimum and maximum are both exactly zero
+
+shared float minsh[NUM_SUBGROUPS]; // per-subgroup minimum, stored by subgroup invocation 0 and combined by workgroup invocation 0
+shared float maxsh[NUM_SUBGROUPS]; // per-subgroup maximum, stored by subgroup invocation 0 and combined by workgroup invocation 0
+
+// For each Br x Bc tile of the mask (input) buffer, read all values and check
+// whether the tile is entirely -inf or entirely zero. A two-bit code records the
+// outcome (MASK_OPT_ALL_NEG_INF, MASK_OPT_ALL_ZERO, or zero for neither). Each
+// workgroup processes 16 consecutive tiles along the column axis and writes
+// their codes out as one packed 32-bit word (tile k occupies bits 2k and 2k+1).
+//
+// TODO: This is a lot of work per workgroup, might make sense to split this into
+// more workgroups in the future.
+
+// Combine the per-invocation min/max of one tile across the whole workgroup and
+// turn the result into that tile's two-bit code, already shifted into place.
+//
+//   min_v, max_v      - this invocation's partial minimum / maximum for the tile
+//   tid               - gl_LocalInvocationIndex of the caller
+//   block_x           - index (0..15) of the tile within the workgroup's group of 16
+//   neg_inf_threshold - a maximum at or below this value marks the tile as all -inf
+//
+// Returns the shifted code on invocation 0 and zero on every other invocation.
+// Contains barrier() calls, so it must be reached by all invocations of the
+// workgroup in uniform control flow.
+uint reduce_and_encode_tile(float min_v, float max_v, uint tid, uint block_x, float neg_inf_threshold) {
+    // Stage 1: reduce inside each subgroup and let its first invocation publish
+    // the partial result to shared memory.
+    min_v = subgroupMin(min_v);
+    max_v = subgroupMax(max_v);
+    if (gl_SubgroupInvocationID == 0) {
+        minsh[gl_SubgroupID] = min_v;
+        maxsh[gl_SubgroupID] = max_v;
+    }
+    barrier();
+
+    // Stage 2: invocation 0 folds the per-subgroup partials and classifies the tile.
+    uint code = 0u;
+    if (tid == 0) {
+        [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+            min_v = min(min_v, minsh[i]);
+            max_v = max(max_v, maxsh[i]);
+        }
+        // Unsigned shifts of the named codes: tile 15 reaches bit 31, which a
+        // signed literal would place in the sign bit.
+        if (max_v <= neg_inf_threshold) {
+            code |= uint(MASK_OPT_ALL_NEG_INF) << (2u * block_x);
+        }
+        if (min_v == 0.0f && max_v == 0.0f) {
+            code |= uint(MASK_OPT_ALL_ZERO) << (2u * block_x);
+        }
+    }
+    // minsh/maxsh are overwritten for the next tile, so nobody may run ahead
+    // until invocation 0 has finished reading them.
+    barrier();
+    return code;
+}
+
+void main() {
+    // Each workgroup handles 16 consecutive tiles of one tile row:
+    // x selects the group of 16 tiles, y selects the tile row, and z is split
+    // by nem2 into the two outer indices i2 and i3.
+    const uint tid = gl_LocalInvocationIndex;
+    const uint i0 = gl_WorkGroupID.x;
+    const uint i1 = gl_WorkGroupID.y;
+    const uint i2 = gl_WorkGroupID.z % nem2;
+    const uint i3 = gl_WorkGroupID.z / nem2;
+
+    // Bit pattern of FLT_MAX / 2. It seeds the running min/max and doubles as
+    // the "-inf" threshold: a tile maximum at or below -FLT_MAX/2 is treated as
+    // all -inf. A tile from which no value was read keeps the seed and is
+    // therefore also reported as all -inf.
+    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+    // Only invocation 0 accumulates non-zero bits here.
+    uint result = 0;
+
+    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
+    // NOTE(review): this path never compares the tile's column offset
+    // (i0 * 16 + block_x) * Bc against nem0; presumably the dispatch and the mask
+    // allocation guarantee all 16 tiles are readable - confirm against the host code.
+    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
+        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            // Each invocation loads four adjacent columns per step; the tile
+            // holds Br * Bc / 4 such groups.
+            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+                uint j0 = (i + tid) % (Bc / 4);
+                uint j1 = (i + tid) / (Bc / 4);
+
+                // Convert the group index to a column, then move to this tile.
+                j0 *= 4;
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                // The element offset is divided by 4 to index the f16vec4 view.
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            }
+            result |= reduce_and_encode_tile(min_v, max_v, tid, block_x, -FLT_MAX_OVER_2);
+        }
+    } else {
+        // General path: scalar loads with per-element bounds checks.
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
+                // When the tile size is not a multiple of the workgroup size,
+                // the last step has invocations past the end of the tile.
+                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
+                    continue;
+                }
+                uint j0 = (i + tid) % Bc;
+                uint j1 = (i + tid) / Bc;
+
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                // Elements outside the mask extents are skipped and do not
+                // influence the tile's min/max.
+                if (j0 < nem0 && j1 < nem1) {
+                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                    min_v = min(min_v, f);
+                    max_v = max(max_v, f);
+                }
+            }
+            result |= reduce_and_encode_tile(min_v, max_v, tid, block_x, -FLT_MAX_OVER_2);
+        }
+    }
+
+    // One packed word per workgroup, stored by invocation 0.
+    if (tid == 0) {
+        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
+    }
+}