#version 450

#extension GL_EXT_control_flow_attributes : enable

// Workgroup width. Specialization constant 0 drives both BLOCK_SIZE and
// local_size_x (via local_size_x_id = 0), so tmpsh below always has exactly one
// slot per invocation. The tree reductions in main() halve the active range each
// step, so BLOCK_SIZE is assumed to be a power of two.
layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Split-K partial results, as addressed by main():
//   first D*ne1*ne2*ne3*k_num floats: partial O values,
//     index D*ne1*(k + k_num*(i2 + ne2*i3)) + D*n + d
//   then, for each (i2, i3) slice and each split k: ne1 L values followed by
//     ne1 M values (so one row's consecutive splits are 2*ne1 floats apart).
layout (binding = 0) readonly buffer A {float data_a[];};
// Per-row sink values, indexed by row n; only read when p.sinks != 0.
layout (binding = 1) readonly buffer B {float data_s[];};
// Merged output, index (i3*ne2 + i2)*ne1*D + D*n + d.
layout (binding = 2) writeonly buffer D {float data_d[];};

layout (push_constant) uniform parameter {
    uint D;      // elements per row
    uint ne1;    // rows per (i2, i3) slice; workgroup x selects the row
    uint ne2;    // extent of i2 (workgroup z is split as z % ne2, z / ne2)
    uint ne3;    // presumably the extent of i3; used to size the O region
    uint k_num;  // number of split-K partial results to merge
    uint sinks;  // nonzero: fold data_s[n] into the normalizer L
} p;

// Reduction scratch for main(), one float per invocation.
shared float tmpsh[BLOCK_SIZE];
23
// Merges the k_num split-K partial results for one output row.
//
// Workgroup mapping (from the index math below):
//   gl_WorkGroupID.x -> row n
//   gl_WorkGroupID.y -> which BLOCK_SIZE-wide slice of the D columns is written
//   gl_WorkGroupID.z -> flattened (i2, i3), i2 = z % ne2, i3 = z / ne2
//
// For each split k the input holds a partial row O_k[d], an M value m_k and an
// L value l_k. With m_max = max_k m_k the stored result is
//   out[d] = sum_k exp(m_k - m_max) * O_k[d]  /  sum_k exp(m_k - m_max) * l_k
// When p.sinks != 0, the sink s = data_s[n] adds exp(s - max(m_max, s)) to the
// denominator only, and both sums are rebased to max(m_max, s).
void main() {
    const uint n   = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;
    const uint i2  = gl_WorkGroupID.z % p.ne2;
    const uint i3  = gl_WorkGroupID.z / p.ne2;

    const uint D     = p.D;
    const uint k_num = p.k_num;

    // Flattened (i2, i3) slice index, shared by the O and L/M addressing.
    const uint slice = i2 + p.ne2 * i3;

    // The L/M region starts after all D*ne1*ne2*ne3*k_num partial O values.
    // Per (slice, split) it holds ne1 L values followed by ne1 M values.
    const uint lm_base   = D * p.ne1 * p.ne2 * p.ne3 * k_num
                         + p.ne1 * 2 * (0/*split_k_index*/ + k_num * slice);
    const uint l_offset  = lm_base + n;
    const uint m_offset  = lm_base + p.ne1 + n;
    const uint lm_stride = p.ne1 * 2;

    // Max of the M values over all splits: each invocation scans splits
    // tid, tid + BLOCK_SIZE, ..., then the workgroup reduces through tmpsh.
    float m_max = -1.0/0.0;
    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
        m_max = max(m_max, data_a[m_offset + (k + tid) * lm_stride]);
    }

    tmpsh[tid] = m_max;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            m_max = max(m_max, tmpsh[tid + s]);
            tmpsh[tid] = m_max;
        }
        barrier();
    }
    m_max = tmpsh[0];

    // Every invocation must have read tmpsh[0] before tmpsh is overwritten below.
    barrier();

    // L = sum_k exp(m_k - m_max) * l_k, same strided scan + tree reduction.
    float L = 0;
    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
        const float l = data_a[l_offset + (k + tid) * lm_stride];
        const float m = data_a[m_offset + (k + tid) * lm_stride];
        L += exp(m - m_max) * l;
    }

    tmpsh[tid] = L;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            L += tmpsh[tid + s];
            tmpsh[tid] = L;
        }
        barrier();
    }
    L = tmpsh[0];

    // Rescale factor for both sums when the sink exceeds m_max; 1.0 otherwise.
    // Computed once here so the L and O paths cannot disagree.
    float o_scale = 1.0;
    if (p.sinks != 0) {
        const float sink = data_s[n];
        float sink_term = 1.0;
        if (sink > m_max) {
            o_scale = exp(m_max - sink);
        } else {
            sink_term = exp(sink - m_max);
        }
        L = L * o_scale + sink_term;
    }

    // Reciprocal of the normalizer; a zero normalizer yields a zero row.
    L = (L == 0.0) ? 0.0 : 1.0 / L;

    // D dimension is split across workgroups in the y dimension.
    const uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
    if (d < D) {
        // Partial O for split k: D*ne1*(k + k_num*slice) + D*n + d.
        // The k-independent part is computed once, outside the loop.
        const uint o_stride = D * p.ne1;
        const uint o_base   = o_stride * k_num * slice + D * n + d;

        float O = 0.0;
        [[unroll]] for (uint k = 0; k < k_num; ++k) {
            const float m = data_a[m_offset + k * lm_stride];
            O += exp(m - m_max) * data_a[o_base + k * o_stride];
        }
        // Multiplying by 1.0 is exact, so this only changes O when the sink was
        // larger than m_max.
        O *= o_scale;
        O *= L;

        // Clamp infinities to the largest finite float before storing.
        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
        O = clamp(O, -FLT_MAX, FLT_MAX);

        data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
    }
}