1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : require
  4#extension GL_KHR_shader_subgroup_basic : enable
  5#if USE_SUBGROUP_ADD
  6#extension GL_KHR_shader_subgroup_arithmetic : enable
  7#endif
  8
  9#include "types.glsl"
 10
 11layout(constant_id = 0) const uint D_STATE = 128;
 12layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
 13
 14const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
 15
 16layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 17
 18layout(binding = 0) readonly buffer Src0 { float s0[]; };
 19layout(binding = 1) readonly buffer Src1 { float x[]; };
 20layout(binding = 2) readonly buffer Src2 { float dt[]; };
 21layout(binding = 3) readonly buffer Src3 { float A[]; };
 22layout(binding = 4) readonly buffer Src4 { float B[]; };
 23layout(binding = 5) readonly buffer Src5 { float C[]; };
 24layout(binding = 6) readonly buffer Src6 { int ids[]; };
 25layout(binding = 7) buffer Dst { float d[]; };
 26
 27layout(push_constant) uniform PushConstants {
 28    uint nb02; uint nb03; uint nb12; uint nb13;
 29    uint nb21; uint nb22; uint nb31;
 30    uint nb42; uint nb43; uint nb52; uint nb53;
 31    uint s_off;
 32    uint n_head;
 33    uint d_head;
 34    uint n_group;
 35    uint n_tok;
 36};
 37
 38float softplus(float x) {
 39    if (x <= 20.0) {
 40        return log(1.0 + exp(x));
 41    } else {
 42        return x;
 43    }
 44}
 45
 46#if !USE_SUBGROUP_ADD
 47shared float temp[D_STATE];
 48#endif
 49
 50void main() {
 51    const uint subgroup = gl_SubgroupID;
 52    const uint lane     = gl_SubgroupInvocationID;
 53    const uint tid      = gl_SubgroupID * SUBGROUP_SIZE + lane;
 54    const uint subgroup_idx = gl_WorkGroupID.x  * c_factor + subgroup;
 55
 56    const uint head_idx =  subgroup_idx / d_head;
 57    const uint head_off = (subgroup_idx % d_head) * 4;
 58    const uint seq_idx  = gl_WorkGroupID.y;
 59
 60    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
 61    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
 62    const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
 63    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
 64    const uint A_base_idx = (head_idx * nb31) / 4;
 65    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
 66    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
 67    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
 68    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
 69
 70    const uint stride_x = nb12 / 4;
 71    const uint stride_dt = nb21 / 4;
 72    const uint stride_B = nb42 / 4;
 73    const uint stride_C = nb52 / 4;
 74    const uint stride_y = n_head * d_head;
 75
 76    float state[c_factor];
 77
 78    [[unroll]] for (uint j = 0; j < c_factor; j++) {
 79        state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
 80    }
 81
 82    float a = A[A_base_idx];
 83
 84    for (uint i = 0; i < n_tok; i++) {
 85        float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
 86
 87        float state_sum = 0.0f;
 88
 89        const float dA   = exp(dt_soft_plus * a);
 90        const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
 91        [[unroll]] for (uint j = 0; j < c_factor; j++) {
 92            float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
 93            float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
 94            state[j] = (state[j] * dA) + (B_val * x_dt);
 95            state_sum += state[j] * C_val;
 96        }
 97
 98#if USE_SUBGROUP_ADD
 99        state_sum = subgroupAdd(state_sum);
100#else
101        temp[tid] = state_sum;
102        barrier();
103        [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
104            if (lane < s) {
105                temp[tid] += temp[tid + s];
106            }
107            barrier();
108        }
109        // get the value from lane 0
110        state_sum = temp[subgroup * SUBGROUP_SIZE];
111        barrier();
112#endif
113
114        if (lane == 0) {
115            d[y_base_idx + i * stride_y] = state_sum;
116        }
117    }
118
119    // write back the state
120    [[unroll]]
121    for (int j = 0; j < c_factor; j++) {
122        d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
123    }
124}