1
// Workgroup size: x comes from specialization constant 0 (WorkGroupSize below), y = z = 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3
// Specialization constants. The constant_id values must match the host-side
// pipeline setup, so do not renumber them.
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;  // local_size_x (via local_size_x_id = 0)
layout (constant_id = 1) const uint32_t Br = 1;               // Q rows per tile: Tr = CEIL_DIV(N, Br)
layout (constant_id = 2) const uint32_t Bc = 32;              // KV rows per tile: start_j/end_j are in Bc units
layout (constant_id = 3) const uint32_t HSK = 32;             // K head size (padded to HSK_pad)
layout (constant_id = 4) const uint32_t HSV = 32;             // V head size (padded to HSV_pad)
layout (constant_id = 5) const uint32_t Clamp = 0;            // non-zero enables KV_bounds_check
layout (constant_id = 6) const uint32_t D_split = 16;         // NOTE(review): not used in this chunk
layout (constant_id = 7) const uint32_t SubGroupSize = 32;    // presumably the subgroup size; not used in this chunk
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;     // NOTE(review): not used in this chunk; meaning to confirm
layout (constant_id = 9) const uint32_t Flags = 0;            // feature bits, decoded into USE_MASK_OPT / MASK_ENABLE / LOGIT_SOFTCAP
14
// Individual feature switches unpacked from the Flags specialization constant.
const bool USE_MASK_OPT  = (Flags & 0x1u) != 0u;   // bit 0
const bool MASK_ENABLE   = (Flags & 0x2u) != 0u;   // bit 1
const bool LOGIT_SOFTCAP = (Flags & 0x4u) != 0u;   // bit 2

// Head sizes rounded up to the next multiple of 16 for the coopmat1/coopmat2 paths.
const uint32_t HSK_pad = (HSK + 15u) & ~15u;
const uint32_t HSV_pad = (HSV + 15u) & ~15u;

// Any non-zero Clamp value turns on KV_bounds_check.
const bool KV_bounds_check = Clamp != 0u;
24
// Push constants shared by the flash-attention shaders. Member order and types
// define the block layout and must match what the host pushes; do not reorder.
layout (push_constant) uniform parameter {
    uint32_t N;         // Q row count; init_indices() copies it to N and derives Tr from it
    uint32_t KV;        // K/V row count; caps the split_k KV range in init_indices()

    uint32_t ne1;       // NOTE(review): ne1..ne3 are not used in this chunk
    uint32_t ne2;
    uint32_t ne3;

    // Dimensions 2 and 3 of Q, K and V (by name); neq*/nek* and neq*/nev*
    // give the K/V broadcast ratios rk*/rv* in init_indices().
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    uint32_t nem1;      // presumably mask dimensions; not used in this chunk
    uint32_t nem2;
    uint32_t nem3;

    // Strides. nb01/nb11/nb21 are already in elements (used as q/k/v row
    // strides); nb02 is in bytes (divided by 4 for the GQA Q stride).
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;

    float scale;        // NOTE(review): scale/max_bias/logit_softcap not used in this chunk
    float max_bias;
    float logit_softcap;

    uint32_t mask_n_head_log2;  // low 16 bits (N_LOG2_MASK): n_head_log2 for the slope computation
    float m0;                   // slope base for heads h <  n_head_log2
    float m1;                   // slope base for heads h >= n_head_log2

    uint32_t gqa_ratio; // > 1 selects the grouped-query-attention indexing in init_indices()
    uint32_t split_kv;  // KV rows per split_k slice
    uint32_t k_num;     // number of split_k slices; > 1 enables split_k decoding of gl_WorkGroupID.x
} p;
65
// Bit fields associated with p.mask_n_head_log2. The low 16 bits (N_LOG2_MASK)
// hold n_head_log2, as read by perElemOpComputeSlope. SINK_ENABLE_BIT is
// presumably tested against the same word to enable sink values (data_s);
// its use is not visible in this chunk.
#define SINK_ENABLE_BIT (1<<24)
#define N_LOG2_MASK 0xFFFF
68
// Sink values, indexed by head (read by perElemOpGetSink).
layout (binding = 4) readonly buffer S {float data_s[];};

// Output buffer; perElemOpStoreCol0 also writes per-row values into it.
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};

// Mask-optimization words; presumably hold MASK_OPT_* codes when USE_MASK_OPT
// is set (TODO confirm layout against the producer).
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};

// Codes for data_mask_opt, by name: all -inf vs. all zero.
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2
77
// Values for dequantize4()'s binding_idx argument: read from K or from V.
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
// K (binding 1) and V (binding 2) viewed as packed arrays: vec4 for f32,
// A_TYPE_PACKED16 blocks when the including shader defines that type.
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
87
// Elements per block; falls back to 1 unless the including shader set it.
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif

#if defined(DATA_A_F32)
// f32 K/V are fetched as vec4, so a "block" is four floats (16 bytes).
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16

// Fetch four consecutive f32 values from K or V.
//   ib          - vec4 index relative to a_offset
//   iqs         - ignored; it is currently always zero in the flash attention shaders
//   a_offset    - base offset in vec4 units
//   binding_idx - BINDING_IDX_K reads K, any other value reads V
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint idx = a_offset + ib;
    if (binding_idx != BINDING_IDX_K) {
        return v_packed.v_data_packed[idx];
    }
    return k_packed.k_data_packed[idx];
}
#endif
106
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

// Q4_0: dequantize four 4-bit quants of block (a_offset + ib) as d * (q - 8).
//   iqs & 0xF  - start position in qs; two 16-bit words are read starting at word (iqs & 0xF) / 2
//   iqs & 0x10 - when set, use the high nibble of each byte (shift by 4) instead of the low one
//   binding_idx - BINDING_IDX_K reads K, any other value reads V
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk   = a_offset + ib;
    const uint w     = (iqs & 0xF) / 2;
    const uint shift = (iqs & 0x10) >> 2;

    float d;
    uint  lo;
    uint  hi;
    if (binding_idx == BINDING_IDX_K) {
        d  = float(k_packed.k_data_packed16[blk].d);
        lo = uint(k_packed.k_data_packed16[blk].qs[w + 0]);
        hi = uint(k_packed.k_data_packed16[blk].qs[w + 1]);
    } else {
        d  = float(v_packed.v_data_packed16[blk].d);
        lo = uint(v_packed.v_data_packed16[blk].qs[w + 0]);
        hi = uint(v_packed.v_data_packed16[blk].qs[w + 1]);
    }

    lo >>= shift;
    hi >>= shift;

    // Each 16-bit word holds two bytes; take one nibble from each byte.
    return d * (vec4(lo & 0xF, (lo >> 8) & 0xF, hi & 0xF, (hi >> 8) & 0xF) - 8.0f);
}
#endif
130
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
// Q8_0: dequantize four consecutive int8 quants of block (a_offset + ib),
// starting at quant index iqs, as d * q.
//   binding_idx - BINDING_IDX_K reads K, any other value reads V
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk = a_offset + ib;
    const uint w   = iqs / 2;   // first 16-bit word of qs holding the four quants

    float  d;
    i8vec2 q01;
    i8vec2 q23;
    if (binding_idx == BINDING_IDX_K) {
        d   = float(k_packed.k_data_packed16[blk].d);
        // unpack8 on a 32-bit value, then .xy: vec4 used due to #12147
        q01 = unpack8(int32_t(k_packed.k_data_packed16[blk].qs[w])).xy;
        q23 = unpack8(int32_t(k_packed.k_data_packed16[blk].qs[w + 1])).xy;
    } else {
        d   = float(v_packed.v_data_packed16[blk].d);
        // unpack8 on a 32-bit value, then .xy: vec4 used due to #12147
        q01 = unpack8(int32_t(v_packed.v_data_packed16[blk].qs[w])).xy;
        q23 = unpack8(int32_t(v_packed.v_data_packed16[blk].qs[w + 1])).xy;
    }

    return d * vec4(q01.x, q01.y, q23.x, q23.y);
}
#endif
147
// Integer ceiling division for unsigned a and non-zero b.
// (b) is expanded twice, so pass side-effect-free arguments.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
149
150
// Store column zero. This is used to save per-row m and L values for split_k.
// For element (r, c) with r < N and c == 0, writes elem (converted to D_TYPE)
// to data_o[o_offset + iq2 + r]. Every other element is left alone. Returns elem
// unchanged in all cases, so the matrix it is applied to is not modified.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    const bool is_col0_row = (c == 0) && (r < N);
    if (!is_col0_row) {
        return elem;
    }

    data_o[o_offset + (iq2 + r)] = D_TYPE(elem);
    return elem;
}
160
// Load the slope matrix, indexed by Q's dimension 2.
// Head h = iq2 + (r % gqa_ratio); n_head_log2 = low 16 bits of p.mask_n_head_log2.
// Returns, evaluated in ACC_TYPE:
//     m0 ^ (h + 1)                     if h <  n_head_log2
//     m1 ^ (2*(h - n_head_log2) + 1)   otherwise
// (presumably the ALiBi slope derived from max_bias; confirm on the host side).
// c and elem are unused.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t head        = iq2 + (r % p.gqa_ratio);
    const uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

    ACC_TYPE base;
    int      exph;
    if (head < n_head_log2) {
        base = ACC_TYPE(p.m0);
        exph = int(head + 1);
    } else {
        base = ACC_TYPE(p.m1);
        exph = int(2*(head - n_head_log2) + 1);
    }

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
173
// Load the sink value, indexed by Q's dimension 2.
// The head is iq2 plus the row's position within its GQA group (r % gqa_ratio).
// c and elem are unused.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    return ACC_TYPE(data_s[iq2 + (r % p.gqa_ratio)]);
}
181
// Per-workgroup indices; all of them are assigned by init_indices().
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j;   // row tile, sizes, split_k KV tile range
uint32_t gqa_iq1, iq2, iq3;                              // GQA row and Q dims 2/3
uint32_t rk2, rk3, rv2, rv3;                             // K/V broadcast ratios
uint32_t ik2, ik3, iv2, iv3;                             // K/V dims 2/3
uint32_t q_stride, k_stride, v_stride, m_stride;         // row strides (elements)
185
// Populate the module-scope index globals for this workgroup from
// gl_WorkGroupID and the push constants.
//   gl_WorkGroupID.x - split_k slice and GQA row (k_num > 1), GQA row
//                      (gqa_ratio > 1), or otherwise the row tile i
//   gl_WorkGroupID.y - iq2 / gqa_ratio
//   gl_WorkGroupID.z - iq3
void init_indices()
{
    N  = p.N;
    KV = p.KV;

    const uint32_t wg_x = gl_WorkGroupID.x;
    const bool     gqa  = p.gqa_ratio > 1;

    if (p.k_num > 1) {
        // batch and split_k share gl_WorkGroupID.x
        i             = 0;
        gqa_iq1       = wg_x / p.k_num;
        split_k_index = wg_x % p.k_num;
    } else {
        // No split_k: x is the GQA row when gqa, otherwise the row tile.
        split_k_index = 0;
        i             = gqa ? 0u : wg_x;
        gqa_iq1       = gqa ? wg_x : 0u;
    }

    Tr = CEIL_DIV(N, Br);

    // Bc-sized KV tiles for this split_k slice: first tile, and one past the
    // last tile covering min(KV, (split_k_index + 1) * split_kv).
    start_j = split_k_index * p.split_kv / Bc;
    end_j   = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // Without GQA all rows share iq2 == gl_WorkGroupID.y; with GQA each
    // workgroup covers gqa_ratio consecutive iq2 values starting here.
    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    iq3 = gl_WorkGroupID.z;

    // Broadcast factors from Q onto K and V.
    rk2 = p.neq2 / p.nek2;
    rk3 = p.neq3 / p.nek3;
    rv2 = p.neq2 / p.nev2;
    rv3 = p.neq3 / p.nev3;

    // K and V indices.
    ik2 = iq2 / rk2;
    ik3 = iq3 / rk3;
    iv2 = iq2 / rv2;
    iv3 = iq3 / rv3;

    // nb?1 are already divided by the type size and are in units of elements.
    // With GQA, Q is indexed by iq2, so the stride is nb02, which is in bytes;
    // the / 4 converts it to elements (presumably Q is f32; confirm host side).
    q_stride = gqa ? (p.nb02 / 4) : p.nb01;
    k_stride = p.nb11;
    v_stride = p.nb21;

    // When using grouped query attention, all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
    // that prevents the compiler from folding the "&" through the select
    // and breaking the alignment detection.
    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}
243
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
// 0.6931 ~= ln(2), so the offset is ~3*ln(2) = ln(8), i.e. exp(offset) ~= 8.
// How it is applied is not visible in this chunk.
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;