1#version 450
  2
  3#include "rte.glsl"
  4#include "types.glsl"
  5
  6#if defined(SET_ROWS) && QUANT_K == 1
  7layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
  8const uint BLOCK_SIZE = 512;
  9#else
 10layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 11const uint BLOCK_SIZE = 32;
 12#endif
 13
 14layout (binding = 0) readonly buffer S {float data_s[];};
 15
 16#if defined(SET_ROWS)
 17#include "generic_binary_head.glsl"
 18layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
 19layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
 20
 21#if B_SIZE == 64
 22#define DATA_I_SWIZZLE .x
 23#else
 24#define DATA_I_SWIZZLE
 25#endif
 26
 27#else
 28#include "generic_unary_head.glsl"
 29layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
 30#endif
 31
 32#if defined(DATA_A_Q4_0)
 33void quantize(uint dst_idx, uint src_idx)
 34{
 35    float amax = 0.0;
 36    float vmax = 0.0;
 37
 38    [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
 39        const float v = data_s[src_idx + j];
 40        if (amax < abs(v)) {
 41            amax = abs(v);
 42            vmax = v;
 43        }
 44    }
 45
 46    const float d  = vmax / -8;
 47    const float id = (d != 0.0) ? 1.0/d : 0.0;
 48
 49    data_q[dst_idx].d = float16_t(d);
 50
 51    [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
 52        const float x0 = data_s[src_idx + 0              + j]*id;
 53        const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
 54
 55        const uint xi0 = min(15, int(x0 + 8.5));
 56        const uint xi1 = min(15, int(x1 + 8.5));
 57
 58        data_q[dst_idx].qs[j]  = uint8_t(xi0 | (xi1 << 4));
 59    }
 60}
 61#endif
 62
 63#if defined(DATA_A_Q4_1)
 64void quantize(uint dst_idx, uint src_idx)
 65{
 66    float vmin = 1.0/0.0;
 67    float vmax = -vmin;
 68
 69    [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
 70        const float v = data_s[src_idx + j];
 71
 72        if (v < vmin) vmin = v;
 73        if (v > vmax) vmax = v;
 74    }
 75
 76    const float d  = (vmax - vmin) / ((1 << 4) - 1);
 77    const float id = (d != 0.0) ? 1.0/d : 0.0;
 78
 79    data_q[dst_idx].d = float16_t(d);
 80    data_q[dst_idx].m = float16_t(vmin);
 81
 82    [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
 83        const float x0 = (data_s[src_idx + 0              + j] - vmin)*id;
 84        const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
 85
 86        const uint xi0 = min(15, int(x0 + 0.5));
 87        const uint xi1 = min(15, int(x1 + 0.5));
 88
 89        data_q[dst_idx].qs[j]  = uint8_t(xi0 | (xi1 << 4));
 90    }
 91}
 92#endif
 93
 94#if defined(DATA_A_Q5_0)
 95void quantize(uint dst_idx, uint src_idx)
 96{
 97    float amax = 0.0;
 98    float vmax = 0.0;
 99
100    [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
101        const float v = data_s[src_idx + j];
102        if (amax < abs(v)) {
103            amax = abs(v);
104            vmax = v;
105        }
106    }
107
108    const float d  = vmax / -16;
109    const float id = (d != 0.0) ? 1.0/d : 0.0;
110
111    data_q[dst_idx].d = float16_t(d);
112
113    uint32_t qh = 0;
114    [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
115        const float x0 = data_s[src_idx + 0              + j]*id;
116        const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
117
118        const uint xi0 = min(31, int(x0 + 16.5));
119        const uint xi1 = min(31, int(x1 + 16.5));
120
121        data_q[dst_idx].qs[j]  = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
122        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
123        qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
124    }
125    data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
126    data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
127}
128#endif
129
130#if defined(DATA_A_Q5_1)
// Q5_1: quantizes QUANT_K_Q5_1 consecutive floats starting at data_s[src_idx]
// into the block data_q[dst_idx] (fp16 scale `d`, fp16 minimum `m`, packed low
// nibbles in `qs`, and the fifth bit of every code gathered into `qh`).
//   dst_idx - index of the destination block in data_q
//   src_idx - index of the first source float in data_s
void quantize(uint dst_idx, uint src_idx)
{
    // Range scan, seeded with the first element.
    float lo = data_s[src_idx + 0];
    float hi = lo;
    [[unroll]] for (int k = 1; k < QUANT_K_Q5_1; ++k) {
        const float val = data_s[src_idx + k];
        if (val < lo) {
            lo = val;
        }
        if (val > hi) {
            hi = val;
        }
    }

    // Step size that spreads [lo, hi] over the codes 0..31.
    const float d = (hi - lo) / 31;
    float id = 0.0;
    if (d != 0.0) {
        id = 1.0/d;
    }

    data_q[dst_idx].d = float16_t(d);
    data_q[dst_idx].m = float16_t(lo);

    uint32_t high_bits = 0;
    [[unroll]] for (int k = 0; k < QUANT_K_Q5_1/2; ++k) {
        const float scaled_lo = (data_s[src_idx + 0              + k] - lo)*id;
        const float scaled_hi = (data_s[src_idx + QUANT_K_Q5_1/2 + k] - lo)*id;

        // Round to nearest: add 0.5, then truncate.
        const uint code_lo = uint(scaled_lo + 0.5);
        const uint code_hi = uint(scaled_hi + 0.5);

        // Low four bits of both codes share one byte.
        data_q[dst_idx].qs[k] = uint8_t(((code_hi & 0xf) << 4) | (code_lo & 0xf));

        // Bit 4 of element k goes to bit k of the mask; bit 4 of element
        // k + QUANT_K_Q5_1/2 goes to bit k + QUANT_K_Q5_1/2.
        const uint top_lo = (code_lo & 0x10u) >> 4;
        const uint top_hi = (code_hi & 0x10u) >> 4;
        high_bits |= top_lo << (k + 0);
        high_bits |= top_hi << (k + QUANT_K_Q5_1/2);
    }
    data_q[dst_idx].qh = high_bits;
}
162#endif
163
164#if defined(DATA_A_Q8_0)
// Q8_0: quantizes QUANT_K_Q8_0 consecutive floats starting at data_s[src_idx]
// into the block data_q[dst_idx] (fp16 scale `d` plus one signed 8-bit code
// per element).
//   dst_idx - index of the destination block in data_q
//   src_idx - index of the first source float in data_s
void quantize(uint dst_idx, uint src_idx)
{
    // Largest magnitude found in the block.
    float peak = 0.0;
    [[unroll]] for (int k = 0; k < QUANT_K_Q8_0; ++k) {
        peak = max(peak, abs(data_s[src_idx + k]));
    }

    // Scale so that the peak magnitude corresponds to code 127.
    const float d = peak / ((1 << 7) - 1);
    float id = 0.0;
    if (d != 0.0) {
        id = 1.0/d;
    }

    data_q[dst_idx].d = float16_t(d);

    [[unroll]] for (int k = 0; k < QUANT_K_Q8_0; ++k) {
        const float scaled = data_s[src_idx + k]*id;
        data_q[dst_idx].qs[k] = int8_t(round(scaled));
    }
}
185#endif
186
187#if defined(DATA_A_IQ4_NL)
// Returns the index (0..15) of the kvalues_iq4nl entry nearest to x.
// The bisection below relies on kvalues_iq4nl being sorted in ascending order.
uint best_index(float x) {
    // Values at or beyond either end of the table snap to that end.
    if (x <= kvalues_iq4nl[0]) {
        return 0;
    }
    if (x >= kvalues_iq4nl[15]) {
        return 15;
    }

    // Bisect until the bracket [lo, hi] spans two adjacent entries.
    int lo = 0;
    int hi = 15;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < kvalues_iq4nl[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // Choose the closer neighbour; a tie selects the upper entry.
    const float dist_below = x - kvalues_iq4nl[hi-1];
    const float dist_above = kvalues_iq4nl[hi] - x;
    return dist_below < dist_above ? hi-1 : hi;
}
198
// IQ4_NL: quantizes QUANT_K_IQ4_NL consecutive floats starting at
// data_s[src_idx] into the block data_q[dst_idx] (packed 4-bit indices into
// kvalues_iq4nl plus an fp16 scale `d`).
//   dst_idx - index of the destination block in data_q
//   src_idx - index of the first source float in data_s
void quantize(uint dst_idx, uint src_idx)
{
    // Find the element with the largest magnitude, keeping its sign.
    float best_abs = 0.0;
    float best_val = 0.0;
    [[unroll]] for (int k = 0; k < QUANT_K_IQ4_NL; ++k) {
        const float val = data_s[src_idx + k];
        if (best_abs < abs(val)) {
            best_abs = abs(val);
            best_val = val;
        }
    }

    // Initial scale: the extreme element is mapped onto the first table entry.
    float d = best_val / kvalues_iq4nl[0];
    float id = 0.0;
    if (d != 0.0) {
        id = 1.0/d;
    }

    // Accumulators for the weighted least-squares scale (weights are x*x).
    float sumqx = 0;
    float sumq2 = 0;
    [[unroll]] for (int k = 0; k < QUANT_K_IQ4_NL/2; ++k) {
        const float src_lo = data_s[src_idx + 0                + k];
        const float src_hi = data_s[src_idx + QUANT_K_IQ4_NL/2 + k];

        const uint idx_lo = best_index(src_lo*id);
        const uint idx_hi = best_index(src_hi*id);

        // Byte k holds element k in the low nibble and element
        // k + QUANT_K_IQ4_NL/2 in the high nibble.
        data_q[dst_idx].qs[k] = uint8_t((idx_hi << 4) | idx_lo);

        const float q_lo = kvalues_iq4nl[idx_lo];
        const float q_hi = kvalues_iq4nl[idx_hi];
        const float w_lo = src_lo*src_lo;
        const float w_hi = src_hi*src_hi;
        sumqx += w_lo*q_lo*src_lo + w_hi*q_hi*src_hi;
        sumq2 += w_lo*q_lo*q_lo + w_hi*q_hi*q_hi;
    }

    // Store the refined scale when it is defined, otherwise the initial one.
    float scale = d;
    if (sumq2 > 0) {
        scale = sumqx/sumq2;
    }
    data_q[dst_idx].d = float16_t(scale);
}
233#endif
234
235#if defined(DATA_A_F32) || defined(DATA_A_F16)
// F32 / F16: converts the single float data_s[src_idx] to A_TYPE and stores it
// in data_q[dst_idx].
void quantize(uint dst_idx, uint src_idx)
{
    const float value = data_s[src_idx];
    data_q[dst_idx] = A_TYPE(value);
}
240#endif
241
242#if defined(DATA_A_BF16)
// BF16: passes the single float data_s[src_idx] through fp32_to_bf16 and
// stores the result, converted to A_TYPE, in data_q[dst_idx].
void quantize(uint dst_idx, uint src_idx)
{
    const float value = data_s[src_idx];
    data_q[dst_idx] = A_TYPE(fp32_to_bf16(value));
}
247#endif
248
249#if defined(SET_ROWS)
250
// SET_ROWS entry point. Each invocation converts one group of QUANT_K source
// elements; the second destination index is taken from data_i rather than
// from the source position.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Flatten the 3D workgroup id (strides 1, 512 and 262144 = 512*512), then
    // expand to the index of this invocation's first source element.
    const uint wg_linear = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint idx = (wg_linear * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;

    // Invocations past the end of the data have nothing to do.
    if (idx >= p.ne) {
        return;
    }

    uint i00, i01, i02, i03;
    get_indices(idx, i00, i01, i02, i03);

    // Coordinates of the index entry: i01 is used directly, the two outer
    // indices are wrapped with ne11 / ne12.
    const uint i10 = i01;
    const uint i11 = fastmod(i02, p.ne11);
    const uint i12 = fastmod(i03, p.ne12);

    const uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;

    // i1 takes the place of i01 when addressing the destination; the innermost
    // destination index counts groups of QUANT_K elements.
    const uint src_pos = src0_idx(i00, i01, i02, i03) + get_aoffset();
    const uint dst_pos = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();

    quantize(dst_pos, src_pos);
}
276
277#else
278
// Plain copy entry point. Each invocation converts one group of QUANT_K source
// elements into the destination position derived from the same linear index.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Linear invocation index built from the 3D workgroup id (strides 32, 512
    // and 262144) plus the local id, then expanded to the index of the first
    // source element handled here.
    const uint wg_base = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32;
    const uint idx = (wg_base + gl_LocalInvocationID.x) * QUANT_K;

    // Invocations past the end of the data have nothing to do.
    if (idx >= p.ne) {
        return;
    }

    const uint dst_pos = dst_idx_quant(idx, QUANT_K);
    const uint src_pos = get_aoffset() + src0_idx(idx);

    quantize(dst_pos, src_pos);
}
295
296#endif