1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3typedef char int8_t;
  4typedef uchar uint8_t;
  5typedef short int16_t;
  6typedef ushort uint16_t;
  7typedef int int32_t;
  8typedef uint uint32_t;
  9
 10#define QK4_0                   32
 11
 12//------------------------------------------------------------------------------
 13// block_q4_0
 14//------------------------------------------------------------------------------
 15struct block_q4_0
 16{
 17    half d;
 18    uint8_t qs[QK4_0 / 2];
 19};
 20
 21
 22//------------------------------------------------------------------------------
 23// dequantize_q4_0_f32, dequantize_q4_0_f16
 24//------------------------------------------------------------------------------
 25void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
 26    global ushort * qs = ((global ushort *)xb + 1);
 27    float d1 = il ? (xb->d / 16.h) : xb->d;
 28    float d2 = d1 / 256.f;
 29    float md = -8.h * xb->d;
 30    ushort mask0 = il ? 0x00F0 : 0x000F;
 31    ushort mask1 = mask0 << 8;
 32
 33    reg->s0 = d1 * (qs[0] & mask0) + md;
 34    reg->s1 = d2 * (qs[0] & mask1) + md;
 35
 36    reg->s2 = d1 * (qs[1] & mask0) + md;
 37    reg->s3 = d2 * (qs[1] & mask1) + md;
 38
 39    reg->s4 = d1 * (qs[2] & mask0) + md;
 40    reg->s5 = d2 * (qs[2] & mask1) + md;
 41
 42    reg->s6 = d1 * (qs[3] & mask0) + md;
 43    reg->s7 = d2 * (qs[3] & mask1) + md;
 44
 45    reg->s8 = d1 * (qs[4] & mask0) + md;
 46    reg->s9 = d2 * (qs[4] & mask1) + md;
 47
 48    reg->sa = d1 * (qs[5] & mask0) + md;
 49    reg->sb = d2 * (qs[5] & mask1) + md;
 50
 51    reg->sc = d1 * (qs[6] & mask0) + md;
 52    reg->sd = d2 * (qs[6] & mask1) + md;
 53
 54    reg->se = d1 * (qs[7] & mask0) + md;
 55    reg->sf = d2 * (qs[7] & mask1) + md;
 56}
 57
 58
 59//------------------------------------------------------------------------------
 60// get_rows
 61//------------------------------------------------------------------------------
 62kernel void kernel_get_rows_f32(
 63        global void * src0,
 64        ulong offset0,
 65        global int * src1,
 66        ulong offset1,
 67        global float * dst,
 68        ulong offsetd,
 69        int ne00,
 70        ulong nb01,
 71        ulong nb02,
 72        ulong nb03,
 73        int ne10,
 74        ulong nb10,
 75        ulong nb11,
 76        ulong nb12,
 77        ulong nb1,
 78        ulong nb2,
 79        ulong nb3
 80) {
 81    src0 = (global void*)((global char*)src0 + offset0);
 82    src1 = (global int*)((global char*)src1 + offset1);
 83    dst = (global float*)((global char*)dst + offsetd);
 84
 85    int i10 = get_group_id(0);
 86    int i11 = get_group_id(1);
 87    int i12 = get_group_id(2);
 88
 89    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
 90
 91    int i02 = i11;
 92    int i03 = i12;
 93
 94    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
 95        if (ind >= ne00) {
 96            return;
 97        }
 98        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
 99            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
100    }
101}
102
// Row gather for an f16 source: the dst row at (i10, i11, i12) receives src0
// row r converted to float, where r is the int read from src1 at
// (i10, i11, i12).
//
// Launch shape: one work-group per (i10, i11, i12) = get_group_id(0, 1, 2);
// the work-items of a group stride over the ne00 elements of the row.
//
// src0/offset0, src1/offset1, dst/offsetd : buffers and byte offsets of their
//                                           first element
// ne00             : elements per row
// nb01, nb02, nb03 : src0 byte strides applied to r, i02, i03
// ne10             : not used by this kernel
// nb10, nb11, nb12 : src1 byte strides applied to i10, i11, i12
// nb1, nb2, nb3    : dst byte strides applied to i10, i11, i12
// NOTE(review): r is assumed to be a valid, non-negative row of src0; it is
// not bounds-checked here.
kernel void kernel_get_rows_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    // src0's outer dims are indexed by src1's dims 1 and 2.
    int i02 = i11;
    int i03 = i12;

    // Both row base addresses are loop-invariant; compute them once.
    global half  * src_row = (global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition alone keeps ind < ne00.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        dst_row[ind] = (float) src_row[ind];
    }
}
143
// Row gather for a q4_0 source: the dst row at (i10, i11, i12) receives src0
// row r dequantized to float, where r is the int read from src1 at
// (i10, i11, i12).
//
// Launch shape: one work-group per (i10, i11, i12) = get_group_id(0, 1, 2).
// Each work-item handles 16-value chunks (half of a block_q4_0) and strides
// over the chunks of the row.
//
// src0/offset0, src1/offset1, dst/offsetd : buffers and byte offsets of their
//                                           first element
// ne00             : elements per row
//                    (NOTE(review): assumed to be a multiple of QK4_0;
//                    any remainder below 16 is not written)
// nb01, nb02, nb03 : src0 byte strides applied to r, i02, i03
// ne10             : not used by this kernel
// nb10, nb11, nb12 : src1 byte strides applied to i10, i11, i12
// nb1, nb2, nb3    : dst byte strides applied to i10, i11, i12
// NOTE(review): r is assumed to be a valid, non-negative row of src0; it is
// not bounds-checked here.
kernel void kernel_get_rows_q4_0(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // Each block is dequantized in NL chunks of QK4_0 / NL values.
    const int NL = 2;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    // src0's outer dims are indexed by src1's dims 1 and 2.
    int i02 = i11;
    int i03 = i12;

    // Both row base addresses are loop-invariant; compute them once.
    global struct block_q4_0 * src_row =
        (global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    const int n_chunks = ne00 / (QK4_0 / NL);  // 16 values per chunk

    for (int ind = get_local_id(0); ind < n_chunks; ind += get_local_size(0)) {
        float16 temp;
        dequantize_q4_0_f32(src_row + ind/NL, ind%NL, &temp);
        // vstore16 only requires float alignment of the destination. Storing
        // through a float16 pointer would require 64-byte alignment, which
        // offsetd and nb1..nb3 do not guarantee.
        vstore16(temp, ind, dst_row);
    }
}