1#include "convert.cuh"
  2#include "dequantize.cuh"
  3
  4#include <cstdint>
  5
  6#define CUDA_Q8_0_NE_ALIGN 2048
  7
  8template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  9static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
 10        const int64_t ne00, const int64_t ne01, const int64_t ne02,
 11        const int64_t s01, const int64_t s02, const int64_t s03) {
 12    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 13
 14    if (i00 >= ne00) {
 15        return;
 16    }
 17
 18    const int64_t i01 = blockIdx.y;
 19    const int64_t i02 = blockIdx.z % ne02;
 20    const int64_t i03 = blockIdx.z / ne02;
 21
 22    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
 23
 24    const int64_t ib = ibx0 + i00/qk; // block index
 25    const int64_t iqs = (i00%qk)/qr; // quant index
 26    const int64_t iybs = i00 - i00%qk; // y block start index
 27    const int64_t y_offset = qr == 1 ? 1 : qk/2;
 28
 29    // dequantize
 30    float2 v;
 31    dequantize_kernel(vx, ib, iqs, v);
 32
 33    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
 34    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
 35    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 36}
 37
 38template <bool need_check>
 39static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
 40#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 41    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
 42
 43    const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
 44    const int * x0 = ((int *) vx) + blockIdx.x * nint;
 45    half2 * y2 = (half2 *) (y + i0);
 46
 47    __shared__ int vals[nint];
 48
 49#pragma unroll
 50    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
 51        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
 52            break;
 53        }
 54
 55        const int ix = ix0 + threadIdx.x;
 56        vals[ix] = x0[ix];
 57    }
 58
 59    __syncthreads();
 60
 61#pragma unroll
 62    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
 63        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
 64            return;
 65        }
 66
 67        const half * b0 = ((const half  *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
 68        const half    d = *b0;
 69        const char2  qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
 70
 71        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
 72    }
 73#else
 74    GGML_UNUSED_VARS(vx, y, k);
 75    NO_DEVICE_CODE;
 76#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 77}
 78
 79template<typename dst_t>
 80static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 81
 82    const int64_t i = blockIdx.x;
 83
 84    // assume 32 threads
 85    const int64_t tid = threadIdx.x;
 86    const int64_t il  = tid/8;
 87    const int64_t ir  = tid%8;
 88    const int64_t ib = 8*i + ir;
 89    if (ib >= nb32) {
 90        return;
 91    }
 92
 93    dst_t * y = yy + 256*i + 32*ir + 4*il;
 94
 95    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
 96    const float d = __half2float(x->d);
 97    const float dm = -8*d;
 98
 99    const uint8_t * q = x->qs + 4*il;
100
101    for (int l = 0; l < 4; ++l) {
102        y[l+ 0] = d * (q[l] & 0xF) + dm;
103        y[l+16] = d * (q[l] >>  4) + dm;
104    }
105}
106
// Dequantizes q4_1 data. One thread block covers 8 quant blocks (256 output values);
// the index math assumes 32 threads, each decoding 4 packed bytes (8 values).
//
//   vx   - array of block_q4_1
//   yy   - destination
//   nb32 - number of quant blocks in vx; threads whose quant block index is >= nb32 do nothing
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t group = blockIdx.x;

    // assume 32 threads
    const int64_t tid     = threadIdx.x;
    const int64_t quarter = tid/8;           // which run of 4 quant bytes inside the block
    const int64_t sub     = tid%8;           // which of the 8 quant blocks of this thread block
    const int64_t iblock  = 8*group + sub;
    if (iblock >= nb32) {
        return;
    }

    const block_q4_1 * blk = (const block_q4_1 *)vx + iblock;
    const float2 dm = __half22float2(blk->dm);   // .x multiplies the nibble, .y is added

    const uint8_t * q = blk->qs + 4*quarter;
    dst_t * y = yy + 256*group + 32*sub + 4*quarter;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = dm.x * (q[l] & 0xF) + dm.y;    // low nibble
        y[l+16] = dm.x * (q[l] >>  4) + dm.y;    // high nibble, 16 values further on
    }
}
133
134//================================== k-quants
135
// Dequantizes one block_q2_K (QK_K output values) per thread block.
// Every thread reads a single byte of qs and emits the four 2-bit values packed in it;
// the launcher in this file uses 64 threads.
//
//   vx - array of block_q2_K, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_q2_K & blk = ((const block_q2_K *) vx)[iblock];

    const int64_t tid  = threadIdx.x;
    const int64_t grp  = tid/32;              // 128-value half of the block
    const int64_t lane = tid - 32*grp;        // byte within that half
    const int64_t is   = 8*grp + lane/16;     // first scale entry used by this thread

    const uint8_t q = blk.qs[32*grp + lane];
    dst_t * y = yy + iblock*QK_K + 128*grp + lane;

    float dall = __low2half(blk.dm);
    float dmin = __high2half(blk.dm);

    // each scale byte holds the scale in its low nibble and the min in its high nibble
    const uint8_t * sc = blk.scales + is;
    y[ 0] = dall * (sc[0] & 0xF) * ((q >> 0) & 3) - dmin * (sc[0] >> 4);
    y[32] = dall * (sc[2] & 0xF) * ((q >> 2) & 3) - dmin * (sc[2] >> 4);
    y[64] = dall * (sc[4] & 0xF) * ((q >> 4) & 3) - dmin * (sc[4] >> 4);
    y[96] = dall * (sc[6] & 0xF) * ((q >> 6) & 3) - dmin * (sc[6] >> 4);
}
157
// Dequantizes one block_q3_K (QK_K output values) per thread block; each thread writes
// 4 consecutive values. The launcher in this file uses 64 threads.
//
//   vx - array of block_q3_K, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_q3_K & blk = ((const block_q3_K *) vx)[iblock];

    // decompose threadIdx.x into (n, j, is0) plus a starting offset inside 32 values
    const int64_t r     = threadIdx.x/4;
    const int64_t tid   = r/2;
    const int64_t is0   = r%2;
    const int64_t first = 16*is0 + 4*(threadIdx.x%4);
    const int64_t n     = tid / 4;
    const int64_t j     = tid - 4*n;

    const uint8_t hbit  = 1 << (4*n + j);      // bit of hmask that belongs to this group of values
    const int64_t is    = 8*n + 2*j + is0;     // scale index
    const int     shift = 2*j;                 // position of the 2 low bits inside each qs byte

    // The scale is 6 bits wide: 4 low bits come from the first 8 bytes of `scales`
    // (low or high nibble), the 2 high bits are packed into bytes 8..11.
    const uint8_t * sc = blk.scales;
    int8_t us;
    if (is < 4) {
        us = (sc[is-0] & 0xF) | (((sc[is+8] >> 0) & 3) << 4);
    } else if (is < 8) {
        us = (sc[is-0] & 0xF) | (((sc[is+4] >> 2) & 3) << 4);
    } else if (is < 12) {
        us = (sc[is-8] >>  4) | (((sc[is+0] >> 4) & 3) << 4);
    } else {
        us = (sc[is-8] >>  4) | (((sc[is-4] >> 6) & 3) << 4);
    }

    const float d_all = blk.d;
    const float dl    = d_all * (us - 32);     // stored scales are biased by 32

    dst_t * y = yy + iblock*QK_K + 128*n + 32*j;
    const uint8_t * q  = blk.qs + 32*n;
    const uint8_t * hm = blk.hmask;

    for (int l = first; l < first+4; ++l) {
        // a cleared hmask bit lowers the value by 4
        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & hbit) ? 0 : 4));
    }
}
188
// Unpacks the 6-bit scale `d` and the 6-bit min `m` with index j from the packed array q.
// For j < 4 both values are the low 6 bits of q[j] and q[j+4]; for j >= 4 the low 4 bits
// come from the nibbles of q[j+4] and the high 2 bits from the top bits of q[j-4] / q[j].
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j >= 4) {
        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
        return;
    }

    d = q[j] & 63;
    m = q[j + 4] & 63;
}
197
// Dequantizes one block_q4_K (QK_K output values) per thread block.
// The index math assumes 32 threads; each thread decodes 4 packed bytes into 8 values
// (4 from the low nibbles, 4 from the high nibbles 32 positions further on).
//
//   vx - array of block_q4_K, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_q4_K & blk = ((const block_q4_K *) vx)[iblock];

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;       // 64-value chunk handled by this thread
    const int64_t ir  = tid%8;       // position inside the chunk
    const int64_t is  = 2*il;        // first of the two scale/min pairs of the chunk
    const int64_t n   = 4;           // bytes decoded per thread

    const uint8_t * q = blk.qs + 32*il + n*ir;
    dst_t * y = yy + iblock*QK_K + 64*il + n*ir;

    const float dall = __low2half(blk.dm);
    const float dmin = __high2half(blk.dm);

    uint8_t sc, m;

    // scale/min for the low nibbles
    get_scale_min_k4(is + 0, blk.scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;

    // scale/min for the high nibbles
    get_scale_min_k4(is + 1, blk.scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    for (int l = 0; l < n; ++l) {
        y[l + 0] = d1 * (q[l] & 0xF) - m1;
        y[l +32] = d2 * (q[l] >>  4) - m2;
    }
}
228
// Dequantizes one block_q5_K (QK_K output values) per thread block.
// The index math assumes 64 threads; each thread decodes 2 packed bytes into 4 values.
//
//   vx - array of block_q5_K, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_q5_K & blk = ((const block_q5_K *) vx)[iblock];

    // assume 64 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/16;   // 0...3: 64-value chunk
    const int64_t ir  = tid%16;   // 0...15: byte pair inside the chunk
    const int64_t is  = 2*il;     // 0...6: first of the two scale/min pairs of the chunk

    const uint8_t * ql = blk.qs + 32*il + 2*ir;   // low 4 bits of the quants
    const uint8_t * qh = blk.qh + 2*ir;           // one extra high bit per quant
    dst_t * y = yy + iblock*QK_K + 64*il + 2*ir;

    const float dall = __low2half(blk.dm);
    const float dmin = __high2half(blk.dm);

    uint8_t sc, m;
    get_scale_min_k4(is + 0, blk.scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, blk.scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // qh bit 2*il belongs to the low-nibble values, bit 2*il + 1 to the high-nibble values
    const uint8_t hm_lo = 1 << (2*il);
    const uint8_t hm_hi = hm_lo << 1;

    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm_lo ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm_lo ? 16 : 0)) - m1;
    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm_hi ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm_hi ? 16 : 0)) - m2;
}
262
// Dequantizes one block_q6_K (QK_K output values) per thread block.
// The index math assumes 64 threads; each thread produces 4 values, 32 positions apart.
//
//   vx - array of block_q6_K, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_q6_K & blk = ((const block_q6_K *) vx)[iblock];

    // assume 64 threads
    const int64_t tid  = threadIdx.x;
    const int64_t half128 = tid/32;               // 0 or 1: which 128-value half of the block
    const int64_t lane    = tid - 32*half128;     // 0...31
    const int64_t is      = 8*half128 + lane/16;  // first scale index used by this thread

    const float d = blk.d;

    const uint8_t * ql = blk.ql + 64*half128 + lane;   // low 4 bits of the quants
    const uint8_t   qh = blk.qh[32*half128 + lane];    // four 2-bit high parts
    const int8_t  * sc = blk.scales + is;

    dst_t * y = yy + iblock*QK_K + 128*half128 + lane;

    // each quant is 6 bits (4 low + 2 high) and stored with a bias of 32
    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
288
// Dequantizes one block_iq2_xxs (QK_K output values) per thread block.
// Each thread looks up one 8-value grid entry, applies the per-value signs and the scale.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq2_xxs, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq2_xxs & blk = ((const block_iq2_xxs *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint16_t * q2   = blk.qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;           // first 4 bytes: grid indices
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);

    // last 4 bytes: 4 x 7 bits of sign indices plus a 4-bit scale in the top bits
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float    d     = (float)blk.d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
}
307
// Dequantizes one block_iq2_xs (QK_K output values) per thread block.
// Each thread decodes one 16-bit entry of qs: the low 9 bits select a grid entry of
// 8 values, the upper 7 bits select the sign pattern.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq2_xs, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq2_xs & blk = ((const block_iq2_xs *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint16_t * q2 = blk.qs + 4*ib;

    const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
    // one scale nibble per pair of groups: low nibble for il 0/1, high nibble for il 2/3
    const float     d     = (float)blk.d * (0.5f + ((blk.scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t   signs = ksigns_iq2xs[q2[il] >> 9];

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
}
324
// Dequantizes one block_iq2_s (QK_K output values) per thread block.
// The grid index of each 8-value group is 10 bits wide: 8 bits from qs plus 2 bits from qh;
// the sign byte of the group is stored in the second part of qs (offset QK_K/8).
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq2_s, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq2_s & blk = ((const block_iq2_s *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint8_t * grid  = (const uint8_t *)(iq2s_grid + (blk.qs[4*ib+il] | ((blk.qh[ib] << (8-2*il)) & 0x300)));
    // one scale nibble per pair of groups: low nibble for il 0/1, high nibble for il 2/3
    const float     d     = (float)blk.d * (0.5f + ((blk.scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t   signs = blk.qs[QK_K/8+4*ib+il];

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
}
340
// Dequantizes one block_iq3_xxs (QK_K output values) per thread block.
// Each thread combines two 4-value grid entries into 8 outputs; signs and the 4-bit
// scale come from a 32-bit word stored after the first QK_K/4 bytes of qs.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq3_xxs, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq3_xxs & blk = ((const block_iq3_xxs *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint8_t  * q3  = blk.qs + 8*ib;
    const uint16_t * gas = (const uint16_t *)(blk.qs + QK_K/4) + 2*ib;

    const uint8_t * grid_lo = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t * grid_hi = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);

    // 4 x 7 bits of sign indices plus a 4-bit scale in the top bits
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float    d     = (float)blk.d * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid_lo[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid_hi[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
}
363
// Dequantizes one block_iq3_s (QK_K output values) per thread block.
// Each thread combines two 4-value grid entries into 8 outputs. A grid index is 9 bits
// wide: 8 bits from qs plus one bit taken from qh.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq3_s, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq3_s & blk = ((const block_iq3_s *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint8_t * qs = blk.qs + 8*ib;

    const uint8_t * grid_lo = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((blk.qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid_hi = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((blk.qh[ib] << (7-2*il)) & 256)));

    // two sub-blocks share one scale byte: low nibble for even ib, high nibble for odd ib
    const float   d     = (float)blk.d * (1 + 2*((blk.scales[ib/2] >> 4*(ib%2)) & 0xf));
    const uint8_t signs = blk.signs[4*ib + il];

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid_lo[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid_hi[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
}
384
// Dequantizes one block_iq1_s (QK_K output values) per thread block.
// Each thread fetches one 32-bit entry of iq1s_grid_gpu, splits it into 8 nibbles and
// writes d * (nibble + delta) for each of them.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq1_s, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq1_s & blk = ((const block_iq1_s *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    // qh[ib]: bit 15 selects the sign of the delta, bits 12..14 hold the scale,
    // the remaining bits provide 3 high grid-index bits per group
    const float delta = blk.qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d     = (float)blk.d * (2*((blk.qh[ib] >> 12) & 7) + 1);

    const uint32_t packed = iq1s_grid_gpu[blk.qs[4*ib+il] | (((blk.qh[ib] >> 3*il) & 7) << 8)];

    // even nibbles go to the first word, odd nibbles to the second one
    uint32_t grid32[2];
    grid32[0] = packed & 0x0f0f0f0f;
    grid32[1] = (packed >> 4) & 0x0f0f0f0f;
    const int8_t * q = (const int8_t *)grid32;

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
}
405
// Dequantizes one block_iq1_m (QK_K output values) per thread block.
// The block-wide fp16 scale is reassembled from the top nibbles of the four 16-bit
// scale words; the remaining bits of those words hold 3-bit sub-scales.
// Each thread fetches one 32-bit entry of iq1s_grid_gpu, splits it into 8 nibbles and
// writes d * (nibble + delta) for each of them.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq1_m, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq1_m & blk = ((const block_iq1_m *) vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: group of 8 values inside the 32-value sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint16_t * sc = (const uint16_t *)blk.scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

    // 16-value sub-block index: sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4)
    const int64_t ib16 = 2*ib + il/2;
    const float   d    = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);

    // qh holds one nibble per group: 3 high grid-index bits plus the delta sign in bit 3
    const float delta = blk.qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;

    const uint32_t packed = iq1s_grid_gpu[blk.qs[4*ib+il] | (((blk.qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];

    // even nibbles go to the first word, odd nibbles to the second one
    uint32_t grid32[2];
    grid32[0] = packed & 0x0f0f0f0f;
    grid32[1] = (packed >> 4) & 0x0f0f0f0f;
    const int8_t * q = (const int8_t *)grid32;

    dst_t * y = yy + iblock*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
}
430
// Dequantizes QK_K values of iq4_nl data per thread block, i.e. QK_K/QK4_NL consecutive
// block_iq4_nl entries. Nibbles are mapped through the kvalues_iq4nl lookup table.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
// NOTE(review): the kernel receives no element count and performs no bounds check, so
// every launched thread block reads and writes a full QK_K values - confirm that callers
// only launch it over complete (or suitably padded) groups.
//
//   vx - array of block_iq4_nl
//   yy - destination, QK_K values per thread block
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t igroup = blockIdx.x;

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: run of 4 packed bytes inside the quant block
    const int64_t ib  = tid%8; // 0...7: quant block inside this group

    const block_iq4_nl & blk = ((const block_iq4_nl *) vx)[igroup*(QK_K/QK4_NL) + ib];

    const uint8_t * q4 = blk.qs + 4*il;
    const float     d  = (float)blk.d;

    dst_t * y = yy + igroup*QK_K + 32*ib + 4*il;
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];   // low nibble
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];   // high nibble, 16 values further on
    }
}
448
// Dequantizes one block_iq4_xs (QK_K output values) per thread block.
// Each 32-value sub-block has a 6-bit scale (4 low bits in scales_l, 2 high bits in
// scales_h) stored with a bias of 32; nibbles are mapped through kvalues_iq4nl.
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
//   vx - array of block_iq4_xs, indexed by blockIdx.x
//   yy - destination, QK_K values per block
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t iblock = blockIdx.x;
    const block_iq4_xs & blk = ((const block_iq4_xs *)vx)[iblock];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: run of 4 packed bytes inside the sub-block
    const int64_t ib  = tid%8; // 0...7: 32-value sub-block

    const uint8_t * q4 = blk.qs + 16*ib + 4*il;
    const float     d  = (float)blk.d * ((((blk.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((blk.scales_h >> 2*ib) & 3) << 4)) - 32);

    dst_t * y = yy + iblock*QK_K + 32*ib + 4*il;
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];   // low nibble
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];   // high nibble, 16 values further on
    }
}
465
// Dequantizes QK_K values of mxfp4 data per thread block, i.e. QK_K/QK_MXFP4 consecutive
// block_mxfp4 entries. The per-block scale is decoded with ggml_cuda_e8m0_to_fp32 and the
// nibbles are mapped through kvalues_mxfp4 (the table entries are halved on use).
// The index math assumes 32 threads (il in 0..3, ib in 0..7).
//
// NOTE(review): the kernel receives no element count and performs no bounds check, so
// every launched thread block reads and writes a full QK_K values - confirm that callers
// only launch it over complete (or suitably padded) groups.
//
//   vx - array of block_mxfp4
//   yy - destination, QK_K values per thread block
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t igroup = blockIdx.x;

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8; // 0...3: run of 4 packed bytes inside the quant block
    const int64_t ib  = tid%8; // 0...7: quant block inside this group

    const block_mxfp4 & blk = ((const block_mxfp4 *) vx)[igroup*(QK_K/QK_MXFP4) + ib];

    const uint8_t * q4 = blk.qs + 4*il;
    const float     d  = ggml_cuda_e8m0_to_fp32(blk.e);

    dst_t * y = yy + igroup*QK_K + 32*ib + 4*il;
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;   // low nibble
        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;   // high nibble, 16 values further on
    }
}
483
// Host launcher for dequantize_block on a (possibly strided) 4D source.
// Grid layout: x covers ne00 with 2 values per thread, y = ne01, z = ne02*ne03.
//
//   vx, y        - source / destination device pointers
//   ne00..ne03   - extents of the tensor
//   s01/s02/s03  - source strides for dims 1..3, forwarded to the kernel
//   stream       - CUDA stream the kernel is enqueued on
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
    // every thread produces two values, so one thread block covers 2*CUDA_DEQUANTIZE_BLOCK_SIZE of ne00
    const int64_t values_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int64_t blocks_x         = (ne00 + values_per_block - 1) / values_per_block;

    const dim3 num_blocks(blocks_x, ne01, ne02*ne03);
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
        (vx, y, ne00, ne01, ne02, s01, s02, s03);
}
492
// Contiguous convenience wrapper: treats the input as a single row of k values
// (ne01 = ne02 = ne03 = 1) and forwards to dequantize_block_cuda.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    // with a single row the strides are never multiplied by a non-zero index;
    // they are set to the number of quant blocks in the row
    const int64_t row_blocks = k/qk;
    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, row_blocks, row_blocks, row_blocks, stream);
}
497
// Host launcher for dequantize_block_q8_0_f16: one WARP_SIZE-thread block per
// CUDA_Q8_0_NE_ALIGN values (rounded up). The bounds-checked kernel variant is only
// used when k is not a multiple of CUDA_Q8_0_NE_ALIGN.
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int  num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
    const bool has_tail   = k % CUDA_Q8_0_NE_ALIGN != 0;

    if (has_tail) {
        dequantize_block_q8_0_f16<true><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
        return;
    }

    dequantize_block_q8_0_f16<false><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
}
508
// Enqueues dequantize_block_q2_K on `stream`: one 64-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(64, 1, 1);
    dequantize_block_q2_K<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
514
// Enqueues dequantize_block_q3_K on `stream`: one 64-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(64, 1, 1);
    dequantize_block_q3_K<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
520
// Enqueues dequantize_block_q4_0 on `stream`: one 32-thread block per 256 values
// (rounded up); the number of 32-value quant blocks is passed so the kernel can skip the tail.
template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_quant_blocks  = k / 32;
    const int  n_thread_blocks = (k + 255) / 256;
    const dim3 grid_dims(n_thread_blocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_q4_0<<<grid_dims, block_dims, 0, stream>>>(vx, y, n_quant_blocks);
}
527
// Enqueues dequantize_block_q4_1 on `stream`: one 32-thread block per 256 values
// (rounded up); the number of 32-value quant blocks is passed so the kernel can skip the tail.
template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_quant_blocks  = k / 32;
    const int  n_thread_blocks = (k + 255) / 256;
    const dim3 grid_dims(n_thread_blocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_q4_1<<<grid_dims, block_dims, 0, stream>>>(vx, y, n_quant_blocks);
}
534
// Enqueues dequantize_block_q4_K on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_q4_K<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
540
// Enqueues dequantize_block_q5_K on `stream`: one 64-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(64, 1, 1);
    dequantize_block_q5_K<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
546
// Enqueues dequantize_block_q6_K on `stream`: one 64-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(64, 1, 1);
    dequantize_block_q6_K<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
552
// Enqueues dequantize_block_iq2_xxs on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq2_xxs<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
558
// Enqueues dequantize_block_iq2_xs on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq2_xs<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
564
// Enqueues dequantize_block_iq2_s on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq2_s<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
570
// Enqueues dequantize_block_iq3_xxs on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq3_xxs<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
576
// Enqueues dequantize_block_iq3_s on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq3_s<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
582
// Enqueues dequantize_block_iq1_s on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq1_s<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
588
// Enqueues dequantize_block_iq4_nl on `stream`: one 32-thread block per QK_K values,
// with the block count rounded UP.
// NOTE(review): the kernel has no bounds check, so when k is not a multiple of QK_K the
// last thread block touches data beyond k - confirm the buffers are padded accordingly.
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_groups = (k + QK_K - 1) / QK_K;
    const dim3 grid_dims(n_groups, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq4_nl<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
594
// Enqueues dequantize_block_iq1_m on `stream`: one 32-thread block per QK_K values of k.
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = k / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq1_m<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
600
// Enqueues dequantize_block_iq4_xs on `stream`: one 32-thread block per QK_K values,
// with the block count rounded up.
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_superblocks = (k + QK_K - 1) / QK_K;
    const dim3 grid_dims(n_superblocks, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_iq4_xs<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
606
// Enqueues dequantize_block_mxfp4 on `stream`: one 32-thread block per QK_K values,
// with the block count rounded UP.
// NOTE(review): the kernel has no bounds check, so when k is not a multiple of QK_K the
// last thread block touches data beyond k - confirm the buffers are padded accordingly.
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int  n_groups = (k + QK_K - 1) / QK_K;
    const dim3 grid_dims(n_groups, 1, 1);
    const dim3 block_dims(32, 1, 1);
    dequantize_block_mxfp4<<<grid_dims, block_dims, 0, stream>>>(vx, y);
}
612
// Element-wise type conversion kernel: y[dense index] = ggml_cuda_cast<dst_t>(x[strided index]).
// One thread per element of dim 0; blockIdx.y selects dim 1, blockIdx.z encodes dims 2 and 3.
//
//   vx           - source, interpreted as an array of src_t
//   y            - destination, indexed as a dense (ne00, ne01, ne02, ...) array
//   ne00..ne02   - extents
//   s01/s02/s03  - source strides for dims 1..3, in elements of src_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t s01, const int64_t s02, const int64_t s03) {
    const int64_t col = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (col >= ne00) {
        return;
    }

    const int64_t row   = blockIdx.y;
    const int64_t chan  = blockIdx.z % ne02;
    const int64_t batch = blockIdx.z / ne02;

    const src_t * src = (const src_t *) vx;

    const int64_t src_idx = batch*s03 + chan*s02 + row*s01 + col;
    const int64_t dst_idx = ((batch*ne02 + chan)*ne01 + row)*ne00 + col;

    y[dst_idx] = ggml_cuda_cast<dst_t>(src[src_idx]);
}
633
// Host launcher for convert_unary on a (possibly strided) 4D source.
// Grid layout: x covers ne00 with one value per thread, y = ne01, z = ne02*ne03.
//
//   vx, y        - source / destination device pointers
//   ne00..ne03   - extents of the tensor
//   s01/s02/s03  - source strides for dims 1..3, forwarded to the kernel
//   stream       - CUDA stream the kernel is enqueued on
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
    const int64_t blocks_x = (ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;

    const dim3 num_blocks(blocks_x, ne01, ne02*ne03);
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
        (vx, y, ne00, ne01, ne02, s01, s02, s03);
}
642
// Contiguous convenience wrapper: converts k consecutive src_t values by treating them
// as a single row (ne01 = ne02 = ne03 = 1) and forwarding to convert_unary_cuda.
template <typename src_t, typename dst_t>
static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    // with a single row the strides are never multiplied by a non-zero index; k is passed for all of them
    const int64_t stride = k;
    convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, stride, stride, stride, stream);
}
647
// Returns the contiguous converter from `type` to bf16, or nullptr if `type` is not supported.
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32: return convert_unary_cont_cuda<float>;
        case GGML_TYPE_F16: return convert_unary_cont_cuda<half>;
        default:            return nullptr;
    }
}
658
// Returns the contiguous converter/dequantizer from `type` to fp16, or nullptr if `type`
// is not supported. For q8_0 the dedicated half-precision kernel is chosen when
// fp16_available() reports support for the current device's compute capability.
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    switch (type) {
        // simple block formats
        case GGML_TYPE_Q4_0:    return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:    return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:    return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:    return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0: {
            const auto cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
            if (!fp16_available(cc)) {
                return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            }
            return dequantize_block_q8_0_f16_cuda;
        }

        // k-quants
        case GGML_TYPE_Q2_K:    return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:    return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:    return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:    return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:    return dequantize_row_q6_K_cuda;

        // i-quants and mxfp4
        case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:  return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:   return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:   return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:   return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:  return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:  return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:   return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_MXFP4:   return dequantize_row_mxfp4_cuda;

        // plain floating point sources
        case GGML_TYPE_F32:     return convert_unary_cont_cuda<float>;
        case GGML_TYPE_BF16:    return convert_unary_cont_cuda<nv_bfloat16>;

        default:                return nullptr;
    }
}
712
// Returns the contiguous converter/dequantizer from `type` to fp32, or nullptr if `type`
// is not supported.
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    switch (type) {
        // simple block formats
        case GGML_TYPE_Q4_0:    return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:    return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:    return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:    return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:    return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;

        // k-quants
        case GGML_TYPE_Q2_K:    return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:    return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:    return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:    return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:    return dequantize_row_q6_K_cuda;

        // i-quants and mxfp4
        case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:  return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:   return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:   return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:   return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:  return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:  return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:   return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_MXFP4:   return dequantize_row_mxfp4_cuda;

        // plain floating point sources
        case GGML_TYPE_F16:     return convert_unary_cont_cuda<half>;
        case GGML_TYPE_BF16:    return convert_unary_cont_cuda<nv_bfloat16>;

        default:                return nullptr;
    }
}
763
// Returns the converter from `type` to fp16 that accepts explicit extents and source
// strides (non-contiguous input), or nullptr if `type` is not supported.
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:  return convert_unary_cuda<float>;
        case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1: return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0: return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0: return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16: return convert_unary_cuda<nv_bfloat16>;
        default:             return nullptr;
    }
}
784
// Returns the converter from `type` to bf16 that accepts explicit extents and source
// strides (non-contiguous input), or nullptr if `type` is not supported.
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:  return convert_unary_cuda<float, nv_bfloat16>;
        case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1: return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0: return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0: return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_F16:  return convert_unary_cuda<half, nv_bfloat16>;
        default:             return nullptr;
    }
}
805
// Returns the converter from `type` to fp32 that accepts explicit extents and source
// strides (non-contiguous input), or nullptr if `type` is not supported.
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:  return convert_unary_cuda<half, float>;
        case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case GGML_TYPE_Q4_1: return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0: return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1: return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0: return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16: return convert_unary_cuda<nv_bfloat16, float>;
        default:             return nullptr;
    }
}