1#version 450
2
3#extension GL_EXT_control_flow_attributes : require
4#extension GL_KHR_shader_subgroup_basic : enable
5#if USE_SUBGROUP_ADD
6#extension GL_KHR_shader_subgroup_arithmetic : enable
7#endif
8
9#include "types.glsl"
10
// Specialization constants.
// D_STATE (id 0) is the state length handled per row; it is also the
// workgroup's x size, because local_size_x_id below refers to the same id 0.
layout(constant_id = 0) const uint D_STATE = 128;
// Lane count used to split a workgroup into subgroups and to stride state reads.
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;

// Number of state elements each lane owns (size of main()'s per-lane array),
// and the number of subgroup rows a workgroup covers.
const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Initial state, read once per invocation before the token loop.
layout(binding = 0) readonly buffer Src0 {
    float s0[];
};
// Per-token input value, one scalar per subgroup per token.
layout(binding = 1) readonly buffer Src1 {
    float x[];
};
// Per-token step value; passed through softplus() before use.
layout(binding = 2) readonly buffer Src2 {
    float dt[];
};
// One scalar per head, read once before the token loop.
layout(binding = 3) readonly buffer Src3 {
    float A[];
};
// Per-token vectors; D_STATE elements are read per token for each of B and C.
layout(binding = 4) readonly buffer Src4 {
    float B[];
};
layout(binding = 5) readonly buffer Src5 {
    float C[];
};
// Indexed by the sequence index to pick which slot of s0 to start from.
layout(binding = 6) readonly buffer Src6 {
    int ids[];
};
// Output: receives the per-token results and, at a separate offset, the final state.
layout(binding = 7) buffer Dst {
    float d[];
};
26
// Per-dispatch parameters. The nb* values are combined into offsets that
// main() divides by 4 before indexing the float buffers, so they are
// presumably byte strides (4 == sizeof(float)) -- confirm against the host code.
layout(push_constant) uniform PushConstants {
    uint nb02;     // multiplied by the head index in the s0 / final-state offsets
    uint nb03;     // multiplied by the sequence slot in the s0 / final-state offsets
    uint nb12;     // per-token step through x
    uint nb13;     // per-sequence offset into x
    uint nb21;     // per-token step through dt
    uint nb22;     // per-sequence offset into dt
    uint nb31;     // per-head offset into A
    uint nb42;     // per-token step through B
    uint nb43;     // per-sequence offset into B
    uint nb52;     // per-token step through C
    uint nb53;     // per-sequence offset into C
    uint s_off;    // added to the final-state offset inside d
    uint n_head;   // head count; with d_head gives the per-token output stride
    uint d_head;   // rows per head; splits the row index into (head, row-in-head)
    uint n_group;  // n_head / n_group heads share one B/C group
    uint n_tok;    // number of tokens iterated per sequence
};
37
// softplus(x) = log(1 + exp(x)).
// For inputs that are not <= 20 the argument itself is returned instead of
// evaluating the expression.
float softplus(float x) {
    return (x <= 20.0) ? log(1.0 + exp(x)) : x;
}
45
#if !USE_SUBGROUP_ADD
// Scratch array for the barrier-based reduction used when subgroupAdd is not
// compiled in: one float per invocation of the workgroup, indexed by tid.
shared float temp[D_STATE];
#endif

// Sequential scan over n_tok tokens.
//
// Work split, as established by the indexing below:
//   - gl_WorkGroupID.y selects the sequence.
//   - gl_WorkGroupID.x * c_factor + gl_SubgroupID selects a row (subgroup_idx),
//     which is split into head_idx and a row-within-head using d_head.
//   - Each lane keeps c_factor elements of that row's D_STATE-long state in a
//     local array; element j of a lane is state index SUBGROUP_SIZE * j + lane.
//
// For every token the state is updated as state = state * dA + B * (x * dt),
// the products state * C are summed over the whole row, and lane 0 stores the
// sum to d. After the last token the state itself is stored to d at s_base_idx.
void main() {
    const uint subgroup = gl_SubgroupID;
    const uint lane = gl_SubgroupInvocationID;
    // Flat per-workgroup index into temp[]; this assumes every subgroup holds
    // exactly SUBGROUP_SIZE lanes.
    const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
    const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;

    const uint head_idx = subgroup_idx / d_head;
    // Row within the head, scaled by 4 (presumably sizeof(float), so that it
    // combines with the nb* values before the common "/ 4" below).
    const uint head_off = (subgroup_idx % d_head) * 4;
    const uint seq_idx = gl_WorkGroupID.y;

    // n_head / n_group consecutive heads share the same B/C vectors.
    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
    // The initial state is fetched from the slot named by ids[seq_idx] ...
    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
    const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
    const uint A_base_idx = (head_idx * nb31) / 4;
    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
    // y_base_idx is already an element index, so it is not divided by 4.
    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
    // ... while the final state is stored to the slot named by seq_idx itself,
    // shifted by s_off inside d.
    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;

    // Per-token steps, converted to element units like the base indices above.
    const uint stride_x = nb12 / 4;
    const uint stride_dt = nb21 / 4;
    const uint stride_B = nb42 / 4;
    const uint stride_C = nb52 / 4;
    const uint stride_y = n_head * d_head;

    float state[c_factor];

    // Load this lane's share of the initial state.
    [[unroll]] for (uint j = 0; j < c_factor; j++) {
        state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
    }

    // A contributes a single scalar per head; it does not depend on the token.
    float a = A[A_base_idx];

    for (uint i = 0; i < n_tok; i++) {
        float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);

        // This lane's partial sum of state * C for the current token.
        float state_sum = 0.0f;

        const float dA = exp(dt_soft_plus * a);
        const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
        [[unroll]] for (uint j = 0; j < c_factor; j++) {
            float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
            float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
            state[j] = (state[j] * dA) + (B_val * x_dt);
            state_sum += state[j] * C_val;
        }

        // Combine the partial sums of all lanes that share this row.
#if USE_SUBGROUP_ADD
        state_sum = subgroupAdd(state_sum);
#else
        // Halving reduction inside each SUBGROUP_SIZE-wide slice of temp[].
        // Covers every element only if SUBGROUP_SIZE is a power of two.
        temp[tid] = state_sum;
        barrier();
        [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
            if (lane < s) {
                temp[tid] += temp[tid + s];
            }
            barrier();
        }
        // get the value from lane 0
        state_sum = temp[subgroup * SUBGROUP_SIZE];
        // Keep the next token's writes to temp[] from overtaking this read.
        barrier();
#endif

        if (lane == 0) {
            d[y_base_idx + i * stride_y] = state_sum;
        }
    }

    // write back the state
    // (uint index, matching c_factor's unsigned type and the other loops)
    [[unroll]]
    for (uint j = 0; j < c_factor; j++) {
        d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
    }
}