diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp new file mode 100644 index 0000000..c741620 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -0,0 +1,124 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +#include "types.glsl" + +layout(constant_id = 0) const uint D_STATE = 128; +layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; + +const uint32_t c_factor = D_STATE / SUBGROUP_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float s0[]; }; +layout(binding = 1) readonly buffer Src1 { float x[]; }; +layout(binding = 2) readonly buffer Src2 { float dt[]; }; +layout(binding = 3) readonly buffer Src3 { float A[]; }; +layout(binding = 4) readonly buffer Src4 { float B[]; }; +layout(binding = 5) readonly buffer Src5 { float C[]; }; +layout(binding = 6) readonly buffer Src6 { int ids[]; }; +layout(binding = 7) buffer Dst { float d[]; }; + +layout(push_constant) uniform PushConstants { + uint nb02; uint nb03; uint nb12; uint nb13; + uint nb21; uint nb22; uint nb31; + uint nb42; uint nb43; uint nb52; uint nb53; + uint s_off; + uint n_head; + uint d_head; + uint n_group; + uint n_tok; +}; + +float softplus(float x) { + if (x <= 20.0) { + return log(1.0 + exp(x)); + } else { + return x; + } +} + +#if !USE_SUBGROUP_ADD +shared float temp[D_STATE]; +#endif + +void main() { + const uint subgroup = gl_SubgroupID; + const uint lane = gl_SubgroupInvocationID; + const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane; + const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup; + + const uint head_idx = subgroup_idx / d_head; + const uint head_off = (subgroup_idx % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; + + const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; + const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4; + const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; + const uint A_base_idx = (head_idx * nb31) / 4; + const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; + const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx; + const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + + const uint stride_x = nb12 / 4; + const uint stride_dt = nb21 / 4; + const uint stride_B = nb42 / 4; + const uint stride_C = nb52 / 4; + const uint stride_y = n_head * d_head; + + float state[c_factor]; + + [[unroll]] for (uint j = 0; j < c_factor; j++) { + state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane]; + } + + float a = A[A_base_idx]; + + for (uint i = 0; i < n_tok; i++) { + float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + + float state_sum = 0.0f; + + const float dA = exp(dt_soft_plus * a); + const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus; + [[unroll]] for (uint j = 0; j < c_factor; j++) { + float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane]; + float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; + } + +#if USE_SUBGROUP_ADD + state_sum = subgroupAdd(state_sum); +#else + temp[tid] = state_sum; + barrier(); + [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) { + if (lane < s) { + temp[tid] += temp[tid + s]; + } + barrier(); + } + // get the value from lane 0 + state_sum = temp[subgroup * SUBGROUP_SIZE]; + barrier(); +#endif + + if (lane == 0) { + d[y_base_idx + i * stride_y] = state_sum; + } + } + + // write back the state + [[unroll]] + for (int j = 0; j < c_factor; j++) { + d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j]; + } +} |
