summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp124
1 files changed, 124 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
new file mode 100644
index 0000000..c741620
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
@@ -0,0 +1,124 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint D_STATE = 128;
+layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0) readonly buffer Src0 { float s0[]; };
+layout(binding = 1) readonly buffer Src1 { float x[]; };
+layout(binding = 2) readonly buffer Src2 { float dt[]; };
+layout(binding = 3) readonly buffer Src3 { float A[]; };
+layout(binding = 4) readonly buffer Src4 { float B[]; };
+layout(binding = 5) readonly buffer Src5 { float C[]; };
+layout(binding = 6) readonly buffer Src6 { int ids[]; };
+layout(binding = 7) buffer Dst { float d[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint nb02; uint nb03; uint nb12; uint nb13;
+ uint nb21; uint nb22; uint nb31;
+ uint nb42; uint nb43; uint nb52; uint nb53;
+ uint s_off;
+ uint n_head;
+ uint d_head;
+ uint n_group;
+ uint n_tok;
+};
+
+float softplus(float x) {
+ if (x <= 20.0) {
+ return log(1.0 + exp(x));
+ } else {
+ return x;
+ }
+}
+
+#if !USE_SUBGROUP_ADD
+shared float temp[D_STATE];
+#endif
+
+void main() {
+ const uint subgroup = gl_SubgroupID;
+ const uint lane = gl_SubgroupInvocationID;
+ const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
+ const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
+
+ const uint head_idx = subgroup_idx / d_head;
+ const uint head_off = (subgroup_idx % d_head) * 4;
+ const uint seq_idx = gl_WorkGroupID.y;
+
+ const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
+ const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
+ const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
+ const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
+ const uint A_base_idx = (head_idx * nb31) / 4;
+ const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
+ const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
+ const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
+ const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
+
+ const uint stride_x = nb12 / 4;
+ const uint stride_dt = nb21 / 4;
+ const uint stride_B = nb42 / 4;
+ const uint stride_C = nb52 / 4;
+ const uint stride_y = n_head * d_head;
+
+ float state[c_factor];
+
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
+ state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
+ }
+
+ float a = A[A_base_idx];
+
+ for (uint i = 0; i < n_tok; i++) {
+ float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
+
+ float state_sum = 0.0f;
+
+ const float dA = exp(dt_soft_plus * a);
+ const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
+ float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
+ float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
+ state[j] = (state[j] * dA) + (B_val * x_dt);
+ state_sum += state[j] * C_val;
+ }
+
+#if USE_SUBGROUP_ADD
+ state_sum = subgroupAdd(state_sum);
+#else
+ temp[tid] = state_sum;
+ barrier();
+ [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
+ if (lane < s) {
+ temp[tid] += temp[tid + s];
+ }
+ barrier();
+ }
+ // get the value from lane 0
+ state_sum = temp[subgroup * SUBGROUP_SIZE];
+ barrier();
+#endif
+
+ if (lane == 0) {
+ d[y_base_idx + i * stride_y] = state_sum;
+ }
+ }
+
+ // write back the state
+ [[unroll]]
+ for (int j = 0; j < c_factor; j++) {
+ d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
+ }
+}