diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp new file mode 100644 index 0000000..8c92c1a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -0,0 +1,142 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint NUM_SUBGROUPS = 4; +layout (constant_id = 2) const uint Br = 32; +layout (constant_id = 3) const uint Bc = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float16_t data_a[];}; +layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +layout (push_constant) uniform parameter { + uint nem0; + uint nem1; + uint nem2; + uint nbm1; + uint nbm2; + uint nbm3; + uint nbd1; + uint nbd2; + uint nbd3; +}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + +shared float minsh[NUM_SUBGROUPS]; +shared float maxsh[NUM_SUBGROUPS]; + +// For each Br x Bc block of the mask (input) buffer, read all values and check +// if it's all -inf or all zero. Write out a two-bit code indicating which it is +// (or zero for neither). Each workgroup processes 16 tiles and writes out a +// 32-bit result mask. +// +// TODO: This is a lot of work per workgroup, might make sense to split this into +// more workgroups in the future. +void main() { + // Each workgroup handles a row + const uint tid = gl_LocalInvocationIndex; + const uint i0 = gl_WorkGroupID.x; + const uint i1 = gl_WorkGroupID.y; + const uint i2 = gl_WorkGroupID.z % nem2; + const uint i3 = gl_WorkGroupID.z / nem2; + + float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + + uint result = 0; + + // Fast path for fully in-bounds blocks where we can do f16vec4 loads + if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && + ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } else { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) { + if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) { + continue; + } + uint j0 = (i + tid) % Bc; + uint j1 = (i + tid) / Bc; + + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (j0 < nem0 && j1 < nem1) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } + + if (tid == 0) { + data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result; + } +} |
