1#include "common.cuh"
 2
 3static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
 4    const block_q4_0 * x = (const block_q4_0 *) vx;
 5
 6    const float d = x[ib].d;
 7
 8    const int vui = x[ib].qs[iqs];
 9
10    v.x = vui & 0xF;
11    v.y = vui >> 4;
12
13    v.x = (v.x - 8.0f) * d;
14    v.y = (v.y - 8.0f) * d;
15}
16
// Dequantizes the two 4-bit values packed in byte `iqs` of q4_1 block `ib`.
//
//   vx  - untyped pointer, reinterpreted as an array of block_q4_1
//   ib  - index of the block to read
//   iqs - index of the packed byte inside the block's `qs` array
//   v   - output: v.x receives the value decoded from the low nibble,
//         v.y the value decoded from the high nibble
//
// The block's `dm` pair is widened to float2; its .x component is applied as
// a multiplicative scale and its .y component as an additive offset.
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q4_1 * blk = (const block_q4_1 *) vx + ib;

    const float2 scale_offset = __half22float2(blk->dm);
    const int    packed       = blk->qs[iqs];

    // Split the byte into its two nibbles before converting to float.
    const float lo = packed & 0xF;
    const float hi = packed >> 4;

    v.x = (lo * scale_offset.x) + scale_offset.y;
    v.y = (hi * scale_offset.x) + scale_offset.y;
}
30
// Dequantizes the two 5-bit values associated with byte `iqs` of q5_0 block `ib`.
//
//   vx  - untyped pointer, reinterpreted as an array of block_q5_0
//   ib  - index of the block to read
//   iqs - index of the packed byte inside the block's `qs` array
//   v   - output: v.x is built from the low nibble, v.y from the high nibble
//
// The lower four bits of each value come from a nibble of qs[iqs]; the fifth
// bit comes from the 32-bit mask stored in `qh`. Each 5-bit value is shifted
// by -16 and then multiplied by the block scale `d`.
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q5_0 * blk = (const block_q5_0 *) vx + ib;

    const float scale = blk->d;

    // Copy the high-bit mask bytewise into a 32-bit word.
    uint32_t high_bits;
    memcpy(&high_bits, blk->qh, sizeof(high_bits));

    // Bit `iqs` of the mask lands in bit 4 for the low-nibble value;
    // bit `iqs + 16` lands in bit 4 for the high-nibble value.
    const int fifth_lo = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int fifth_hi = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int packed = blk->qs[iqs];
    const float lo = (packed & 0xf) | fifth_lo;
    const float hi = (packed >>  4) | fifth_hi;

    v.x = (lo - 16.0f) * scale;
    v.y = (hi - 16.0f) * scale;
}
48
// Dequantizes the two 5-bit values associated with byte `iqs` of q5_1 block `ib`.
//
//   vx  - untyped pointer, reinterpreted as an array of block_q5_1
//   ib  - index of the block to read
//   iqs - index of the packed byte inside the block's `qs` array
//   v   - output: v.x is built from the low nibble, v.y from the high nibble
//
// The lower four bits of each value come from a nibble of qs[iqs]; the fifth
// bit comes from the 32-bit mask stored in `qh`. The block's `dm` pair is
// widened to float2; its .x component is applied as a multiplicative scale
// and its .y component as an additive offset.
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q5_1 * blk = (const block_q5_1 *) vx + ib;

    const float2 scale_offset = __half22float2(blk->dm);

    // Copy the high-bit mask bytewise into a 32-bit word.
    uint32_t high_bits;
    memcpy(&high_bits, blk->qh, sizeof(high_bits));

    // Bit `iqs` of the mask lands in bit 4 for the low-nibble value;
    // bit `iqs + 16` lands in bit 4 for the high-nibble value.
    const int fifth_lo = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int fifth_hi = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int packed = blk->qs[iqs];
    const float lo = (packed & 0xf) | fifth_lo;
    const float hi = (packed >>  4) | fifth_hi;

    v.x = (lo * scale_offset.x) + scale_offset.y;
    v.y = (hi * scale_offset.x) + scale_offset.y;
}
66
// Dequantizes two consecutive entries of q8_0 block `ib`.
//
//   vx  - untyped pointer, reinterpreted as an array of block_q8_0
//   ib  - index of the block to read
//   iqs - index of the first of the two `qs` entries to read
//   v   - output: v.x = qs[iqs] * d, v.y = qs[iqs + 1] * d
//
// `d` is the block's scale, converted to float before the multiply.
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q8_0 * blk = (const block_q8_0 *) vx + ib;

    const float scale = blk->d;

    // Convert each stored entry to float, then apply the scale.
    const float first  = blk->qs[iqs + 0];
    const float second = blk->qs[iqs + 1];

    v.x = first  * scale;
    v.y = second * scale;
}