1
// Workgroup size: x is taken from specialization constant 0 (WorkGroupSize below);
// y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Specialization constants. The values written here are defaults that apply
// when a constant_id is not overridden at pipeline creation.
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
// Row tile size: init_indices() computes Tr = CEIL_DIV(N, Br).
layout (constant_id = 1) const uint32_t Br = 1;
// KV tile size: init_indices() expresses start_j/end_j in units of Bc.
layout (constant_id = 2) const uint32_t Bc = 32;
// Head sizes (presumably for K and V respectively - TODO confirm); HSK_pad/HSV_pad
// below are these values rounded up to a multiple of 16.
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
// Non-zero enables KV_bounds_check (see below).
layout (constant_id = 5) const uint32_t Clamp = 0;
// D_split, SubGroupSize and K_LOAD_SHMEM are not referenced in this part of the file.
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
// Bit field decoded into USE_MASK_OPT (bit 0), MASK_ENABLE (bit 1) and LOGIT_SOFTCAP (bit 2).
layout (constant_id = 9) const uint32_t Flags = 0;
 14
// Feature switches decoded from the Flags specialization constant:
// bit 0 -> USE_MASK_OPT, bit 1 -> MASK_ENABLE, bit 2 -> LOGIT_SOFTCAP.
const bool USE_MASK_OPT  = (Flags & 1) != 0;
const bool MASK_ENABLE   = (Flags & 2) != 0;
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;

// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
// (adding 15 and then clearing the low four bits rounds up).
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;

// True whenever the Clamp specialization constant is non-zero.
const bool KV_bounds_check = Clamp != 0;
 24
// Push constants for the flash attention shaders. Member order and types define
// the layout shared with the host, so they must not be reordered.
layout (push_constant) uniform parameter {
    // Copied into the file-scope N and KV by init_indices(). N is tiled by Br,
    // KV bounds the end of the per-split range and is the non-GQA mask stride.
    uint32_t N;
    uint32_t KV;

    // Not referenced in this part of the file.
    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    // Dimension 2/3 extents used for broadcasting: init_indices() computes
    // rk2 = neq2/nek2, rk3 = neq3/nek3, rv2 = neq2/nev2, rv3 = neq3/nev3.
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    // Presumably the mask extents - TODO confirm; not referenced in this part of the file.
    uint32_t nem1;
    uint32_t nem2;
    uint32_t nem3;

    // Strides. Per the note in init_indices(), nb?1 are already divided by the
    // type size (units of elements) while nb02 is in bytes.
    // nb01 / nb02 feed q_stride, nb11 is k_stride, nb21 is v_stride.
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;

    // Not referenced in this part of the file.
    float scale;
    float max_bias;
    float logit_softcap;

    // Low 16 bits (N_LOG2_MASK) hold n_head_log2 as used by perElemOpComputeSlope.
    // NOTE(review): SINK_ENABLE_BIT is presumably a flag stored in this same word - confirm.
    uint32_t mask_n_head_log2;
    // Slope bases: m0 is used for heads below n_head_log2, m1 for the others.
    float m0;
    float m1;

    // When > 1, each workgroup covers gqa_ratio consecutive values of iq2.
    uint32_t gqa_ratio;
    // Number of KV elements per split; split s covers
    // [s * split_kv, min(KV, (s + 1) * split_kv)).
    uint32_t split_kv;
    // Number of KV splits; when > 1, gl_WorkGroupID.x encodes both gqa_iq1 and split_k_index.
    uint32_t k_num;
} p;
 65
// NOTE(review): presumably a flag bit tested against p.mask_n_head_log2 - confirm;
// it is not referenced in this part of the file.
#define SINK_ENABLE_BIT (1<<24)
// Selects the n_head_log2 field of p.mask_n_head_log2 (see perElemOpComputeSlope).
#define N_LOG2_MASK 0xFFFF

// Sink values, read by perElemOpGetSink and indexed by head.
layout (binding = 4) readonly buffer S {float data_s[];};

// Output buffer; perElemOpStoreCol0 writes to it.
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};

// Not read in this part of the file; presumably holds the MASK_OPT_* codes
// defined below - TODO confirm.
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};

#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2
 77
// Selector values for the binding_idx argument of dequantize4(). The
// dequantize4() variants below only compare against BINDING_IDX_K; any other
// value reads from the V buffer.
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
// K (binding 1) and V (binding 2) storage buffers. The element type depends on
// the data format: vec4 for DATA_A_F32, otherwise A_TYPE_PACKED16 when that
// macro is defined.
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

// Fallback when BLOCK_SIZE has not been defined earlier; the DATA_A_F32
// section below redefines it to 4.
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif
 91
 92#if defined(DATA_A_F32)
 93#undef BLOCK_SIZE
 94#define BLOCK_SIZE 4
 95#define BLOCK_BYTE_SIZE 16
 96
 97vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
 98    // iqs is currently always zero in the flash attention shaders
 99    if (binding_idx == BINDING_IDX_K) {
100        return k_packed.k_data_packed[a_offset + ib];
101    } else {
102        return v_packed.v_data_packed[a_offset + ib];
103    }
104}
105#endif
106
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

// Q4_0 variant: returns four dequantized values from block (a_offset + ib) of
// the K buffer (binding_idx == BINDING_IDX_K) or of the V buffer (otherwise).
// The low four bits of iqs pick two consecutive qs entries; bit 4 of iqs
// selects the upper nibbles instead of the lower ones. Every 4-bit value has 8
// subtracted and is then multiplied by the block's d field.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk          = a_offset + ib;
    const uint qs_idx       = (iqs & 0xF) / 2;
    // 4 when bit 4 of iqs is set, 0 otherwise.
    const uint nibble_shift = (iqs & 0x10) >> 2;

    uint  q_lo;
    uint  q_hi;
    float d;
    if (binding_idx == BINDING_IDX_K) {
        q_lo = uint(k_packed.k_data_packed16[blk].qs[qs_idx + 0]);
        q_hi = uint(k_packed.k_data_packed16[blk].qs[qs_idx + 1]);
        d    = float(k_packed.k_data_packed16[blk].d);
    } else {
        q_lo = uint(v_packed.v_data_packed16[blk].qs[qs_idx + 0]);
        q_hi = uint(v_packed.v_data_packed16[blk].qs[qs_idx + 1]);
        d    = float(v_packed.v_data_packed16[blk].d);
    }

    q_lo >>= nibble_shift;
    q_hi >>= nibble_shift;

    // Each qs entry contributes two values, taken from bits 0..3 and bits 8..11.
    const vec4 quants = vec4(q_lo & 0xF, (q_lo >> 8) & 0xF, q_hi & 0xF, (q_hi >> 8) & 0xF);
    return d * (quants - 8.0f);
}
#endif
130
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
// Q8_0 variant: returns four dequantized values from block (a_offset + ib) of
// the K buffer (binding_idx == BINDING_IDX_K) or of the V buffer (otherwise).
// Two consecutive qs entries starting at iqs / 2 each provide two signed 8-bit
// values, which are multiplied by the block's d field.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk    = a_offset + ib;
    const uint qs_idx = iqs / 2;

    int32_t raw0;
    int32_t raw1;
    float   d;
    if (binding_idx == BINDING_IDX_K) {
        raw0 = int32_t(k_packed.k_data_packed16[blk].qs[qs_idx]);
        raw1 = int32_t(k_packed.k_data_packed16[blk].qs[qs_idx + 1]);
        d    = float(k_packed.k_data_packed16[blk].d);
    } else {
        raw0 = int32_t(v_packed.v_data_packed16[blk].qs[qs_idx]);
        raw1 = int32_t(v_packed.v_data_packed16[blk].qs[qs_idx + 1]);
        d    = float(v_packed.v_data_packed16[blk].d);
    }

    // The 32-bit unpack8 (four-component result) is used on purpose and only
    // .xy is kept; vec4 used due to #12147.
    const i8vec2 v0 = unpack8(raw0).xy;
    const i8vec2 v1 = unpack8(raw1).xy;

    return d * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
147
// Integer ceiling division of a by b (assuming a + b - 1 does not overflow).
// This is a function-like macro and b is expanded twice, so do not pass
// arguments with side effects.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
149
150
// Per-element callback that writes column zero of a matrix to data_o; used to
// save the per-row m and L values for split_k.
//   r, c     - row and column of the element being visited
//   elem     - the element's value; always returned unchanged
//   o_offset - base offset into data_o
//   iq2      - added to the row index to form the destination offset
//   N        - number of valid rows; rows >= N are not stored
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Only in-range rows of column 0 are written.
    if (r >= N || c != 0) {
        return elem;
    }

    data_o[o_offset + (iq2 + r)] = D_TYPE(elem);
    return elem;
}
160
// Per-element callback producing the slope for a row, indexed by Q's dimension 2.
// The head index is iq2 + (r % p.gqa_ratio). Heads below n_head_log2 (the low
// 16 bits of p.mask_n_head_log2) get m0^(h + 1); the remaining heads get
// m1^(2 * (h - n_head_log2) + 1). The c and elem arguments are ignored.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t head      = iq2 + (r % p.gqa_ratio);
    const uint32_t head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

    ACC_TYPE slope_base;
    int      slope_exp;
    if (head < head_log2) {
        slope_base = ACC_TYPE(p.m0);
        slope_exp  = int(head + 1);
    } else {
        slope_base = ACC_TYPE(p.m1);
        slope_exp  = int(2*(head - head_log2) + 1);
    }

    return ACC_TYPE(pow(slope_base, ACC_TYPE(slope_exp)));
}
173
// Per-element callback returning the sink value for a row, indexed by Q's
// dimension 2: reads data_s at head index iq2 + (r % p.gqa_ratio).
// The c and elem arguments are ignored.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    return ACC_TYPE(data_s[iq2 + (r % p.gqa_ratio)]);
}
181
// File-scope indices and strides, all assigned by init_indices().
uint32_t N, KV;                                   // copies of p.N and p.KV
uint32_t i, gqa_iq1, split_k_index;               // decoded from gl_WorkGroupID.x
uint32_t Tr;                                      // CEIL_DIV(N, Br)
uint32_t start_j, end_j;                          // KV range of this split, in units of Bc
uint32_t iq2, iq3;                                // Q indices for dimensions 2 and 3
uint32_t rk2, rk3, rv2, rv3;                      // K / V broadcast factors
uint32_t ik2, ik3, iv2, iv3;                      // K / V indices for dimensions 2 and 3
uint32_t q_stride, k_stride, v_stride, m_stride;  // Q, K, V and mask strides
185
// Fills the file-scope index and stride variables for the current workgroup
// from the push constants (p) and gl_WorkGroupID.
void init_indices()
{
    const uint32_t wg_x = gl_WorkGroupID.x;

    N = p.N;
    KV = p.KV;

    // gl_WorkGroupID.x means one of three things depending on the mode; start
    // from zero and override only what the active mode provides.
    i = 0;
    gqa_iq1 = 0;
    split_k_index = 0;
    if (p.k_num > 1) {
        // batch and split_k share gl_WorkGroupID.x
        gqa_iq1 = wg_x / p.k_num;
        split_k_index = wg_x % p.k_num;
    } else if (p.gqa_ratio > 1) {
        gqa_iq1 = wg_x;
    } else {
        i = wg_x;
    }

    Tr = CEIL_DIV(N, Br);

    // KV range handled by this split, converted to units of Bc; the end is
    // clamped to KV and rounded up.
    const uint32_t kv_first = split_k_index * p.split_kv;
    const uint32_t kv_limit = min(KV, (split_k_index + 1) * p.split_kv);
    start_j = kv_first / Bc;
    end_j = CEIL_DIV(kv_limit, Bc);

    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    iq3 = gl_WorkGroupID.z;

    // K: broadcast factors, then indices
    rk2 = p.neq2/p.nek2;
    rk3 = p.neq3/p.nek3;
    ik2 = iq2 / rk2;
    ik3 = iq3 / rk3;

    // V: broadcast factors, then indices
    rv2 = p.neq2/p.nev2;
    rv3 = p.neq3/p.nev3;
    iv2 = iq2 / rv2;
    iv3 = iq3 / rv3;

    // nb?1 are already divided by the type size and are in units of elements.
    k_stride = p.nb11;
    v_stride = p.nb21;
    // When using grouped query attention, Q is indexed by iq2, so the stride
    // should be nb02 (which is in bytes).
    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    // When using grouped query attention, all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
    // that prevents the compiler from folding the "&" through the select
    // and breaking the alignment detection. Keep this expression as written.
    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}
243
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
// 0.6931 approximates ln(2), so the value is about 3*ln(2) ~= 2.079, whose
// exponential is about 8. The constant is not referenced in this part of the
// file, so how it is applied is not shown here.
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;