1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4
  5layout(constant_id = 0) const uint BLOCK_SIZE = 32;
  6
  7layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
  8
  9layout (binding = 0) readonly buffer A {float data_a[];};
 10layout (binding = 1) readonly buffer B {float data_s[];};
 11layout (binding = 2) writeonly buffer D {float data_d[];};
 12
 13layout (push_constant) uniform parameter {
 14    uint D;
 15    uint ne1;
 16    uint ne2;
 17    uint ne3;
 18    uint k_num;
 19    uint sinks;
 20} p;
 21
 22shared float tmpsh[BLOCK_SIZE];
 23
 24void main() {
 25    // Each workgroup handles a row
 26    const uint n = gl_WorkGroupID.x;
 27    const uint tid = gl_LocalInvocationID.x;
 28    const uint i2 = gl_WorkGroupID.z % p.ne2;
 29    const uint i3 = gl_WorkGroupID.z / p.ne2;
 30
 31    uint D = p.D;
 32    uint k_num = p.k_num;
 33
 34    uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
 35    uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
 36    uint lm_stride = p.ne1 * 2;
 37
 38    // Compute the max m value for the row
 39    float m_max = -1.0/0.0;
 40    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
 41        float m = data_a[m_offset + (k + tid) * lm_stride];
 42        m_max = max(m_max, m);
 43    }
 44
 45    // reduce across the workgroup
 46    tmpsh[tid] = m_max;
 47    barrier();
 48    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
 49        if (tid < s) {
 50            m_max = max(m_max, tmpsh[tid + s]);
 51            tmpsh[tid] = m_max;
 52        }
 53        barrier();
 54    }
 55    m_max = tmpsh[0];
 56
 57    barrier();
 58
 59    // Compute L based on m_max
 60    float L = 0;
 61    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
 62        float l = data_a[l_offset + (k + tid) * lm_stride];
 63        float m = data_a[m_offset + (k + tid) * lm_stride];
 64        L += exp(m - m_max) * l;
 65    }
 66
 67    // reduce across the workgroup
 68    tmpsh[tid] = L;
 69    barrier();
 70    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
 71        if (tid < s) {
 72            L += tmpsh[tid + s];
 73            tmpsh[tid] = L;
 74        }
 75        barrier();
 76    }
 77    L = tmpsh[0];
 78
 79    float sink;
 80    if (p.sinks != 0) {
 81        sink = data_s[n];
 82
 83        float ms = 1.0f;
 84        float vs = 1.0f;
 85
 86        if (sink > m_max) {
 87            ms = exp(m_max - sink);
 88        } else {
 89            vs = exp(sink - m_max);
 90        }
 91
 92        L = L*ms + vs;
 93    }
 94
 95    L = (L == 0.0) ? 0.0 : 1.0 / L;
 96
 97    // D dimension is split across workgroups in the y dimension
 98    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
 99    // Scale and sum the O contributions based on m_max and store the result to memory
100    if (d < D) {
101        float O = 0.0;
102        [[unroll]] for (uint k = 0; k < k_num; ++k) {
103            uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
104            float m = data_a[m_offset + k * lm_stride];
105            O += exp(m - m_max) * data_a[o_offset];
106        }
107        if (p.sinks != 0) {
108            if (sink > m_max) {
109                float ms = 1.0f;
110                ms = exp(m_max - sink);
111                O *= ms;
112            }
113        }
114        O *= L;
115
116        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
117        O = clamp(O, -FLT_MAX, FLT_MAX);
118
119        data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
120    }
121}