1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
// Fixed-width integer aliases built on the OpenCL C scalar types, so that
// code written against <stdint.h>-style names can be used in this file.

// Signed variants.
typedef char  int8_t;
typedef short int16_t;
typedef int   int32_t;

// Unsigned variants.
typedef uchar  uint8_t;
typedef ushort uint16_t;
typedef uint   uint32_t;
9
// Number of quantized values stored in one block_q4_0.
#define QK4_0 32

//------------------------------------------------------------------------------
// block_q4_0
//
// One quantization block: a half-precision factor `d` followed by
// QK4_0 / 2 bytes of packed data (two 4-bit values per byte).
//------------------------------------------------------------------------------
struct block_q4_0 {
    half    d;               // per-block multiplier applied when dequantizing
    uint8_t qs[QK4_0 / 2];   // packed 4-bit quants, two per byte
};
20
21
22//------------------------------------------------------------------------------
23// dequantize_q4_0_f32, dequantize_q4_0_f16
24//------------------------------------------------------------------------------
/*
 * Dequantize 16 of the 32 values held in one block_q4_0 into a float16.
 *
 *   xb  - block to read (scale `d` followed by 16 packed bytes).
 *   il  - 0 selects the low nibble of every byte, non-zero selects the high
 *         nibble.
 *   reg - receives the 16 results; lane k is produced from byte k of xb->qs.
 *
 * Each 4-bit value q is mapped to d * q - 8 * d.
 *
 * The packed bytes are read two at a time as 16-bit words and the selected
 * nibble is masked out *without* shifting it down. The resulting factor of
 * 16 (high nibble) and/or 256 (high byte of the word) is folded into the
 * scale instead: d1 is used for the low byte, d2 = d1 / 256 for the high byte.
 * NOTE(review): mapping "low byte of the word" to "earlier byte in memory"
 * assumes a little-endian device - confirm for the targets in use.
 *
 * All scale arithmetic is done in float. Dividing the half-precision scale by
 * 16 in half precision underflows into the half subnormal range for
 * |d| < 2^-10, which loses bits (or yields zero on devices without half
 * denormal support). In float the divisions and the multiply by 8 are exact.
 */
void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
    // Skip the 2-byte scale; view the 16 quant bytes as 8 16-bit words.
    global ushort * qs = ((global ushort *)xb + 1);

    // Widen the scale once; every derived factor is computed in float.
    const float d  = xb->d;
    const float d1 = il ? (d / 16.f) : d;   // compensates the unshifted high nibble
    const float d2 = d1 / 256.f;            // compensates the unshifted high byte
    const float md = -8.f * d;              // constant offset: q -> q - 8

    const ushort mask0 = il ? 0x00F0 : 0x000F;  // nibble in the low byte
    const ushort mask1 = mask0 << 8;            // same nibble in the high byte

    reg->s0 = d1 * (qs[0] & mask0) + md;
    reg->s1 = d2 * (qs[0] & mask1) + md;

    reg->s2 = d1 * (qs[1] & mask0) + md;
    reg->s3 = d2 * (qs[1] & mask1) + md;

    reg->s4 = d1 * (qs[2] & mask0) + md;
    reg->s5 = d2 * (qs[2] & mask1) + md;

    reg->s6 = d1 * (qs[3] & mask0) + md;
    reg->s7 = d2 * (qs[3] & mask1) + md;

    reg->s8 = d1 * (qs[4] & mask0) + md;
    reg->s9 = d2 * (qs[4] & mask1) + md;

    reg->sa = d1 * (qs[5] & mask0) + md;
    reg->sb = d2 * (qs[5] & mask1) + md;

    reg->sc = d1 * (qs[6] & mask0) + md;
    reg->sd = d2 * (qs[6] & mask1) + md;

    reg->se = d1 * (qs[7] & mask0) + md;
    reg->sf = d2 * (qs[7] & mask1) + md;
}
57
58
59//------------------------------------------------------------------------------
60// get_rows
61//------------------------------------------------------------------------------
/*
 * Gather rows of float data: for every index stored in src1, copy the
 * selected row of src0 into dst.
 *
 *   src0, offset0   - source buffer and byte offset to its data (float rows).
 *   src1, offset1   - index buffer and byte offset to its data (int).
 *   dst,  offsetd   - destination buffer and byte offset to its data.
 *   ne00            - number of elements copied per row.
 *   nb01..nb03      - byte strides of src0 (row, then the two outer dims).
 *   ne10            - not used by this kernel.
 *   nb10..nb12      - byte strides of src1.
 *   nb1..nb3        - byte strides of dst.
 *
 * Work-group (i10, i11, i12) handles one index entry; the work-items of a
 * group step through the row in dimension 0, each advancing by the local size.
 */
kernel void kernel_get_rows_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Byte-addressed bases with the buffer offsets applied.
    global char * src0_base = (global char *)src0 + offset0;
    global char * src1_base = (global char *)src1 + offset1;
    global char * dst_base  = (global char *)dst  + offsetd;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row number to fetch from src0 for this work-group.
    int r = *(global int *)(src1_base + i12*nb12 + i11*nb11 + i10*nb10);

    // The two outer coordinates of src0 follow the outer index coordinates.
    int i02 = i11;
    int i03 = i12;

    // Row addresses do not depend on the loop variable; compute them once.
    global float * src_row = (global float *)(src0_base + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *)(dst_base + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition alone bounds ind, so no extra range check is needed.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        dst_row[ind] = src_row[ind];
    }
}
102
/*
 * Gather rows of half data and widen them to float: for every index stored
 * in src1, read the selected row of src0 as half and write it to dst as
 * float.
 *
 *   src0, offset0   - source buffer and byte offset to its data (half rows).
 *   src1, offset1   - index buffer and byte offset to its data (32-bit int).
 *   dst,  offsetd   - destination buffer and byte offset to its data (float).
 *   ne00            - number of elements converted per row.
 *   nb01..nb03      - byte strides of src0 (row, then the two outer dims).
 *   ne10            - not used by this kernel.
 *   nb10..nb12      - byte strides of src1.
 *   nb1..nb3        - byte strides of dst.
 *
 * Work-group (i10, i11, i12) handles one index entry; the work-items of a
 * group step through the row in dimension 0, each advancing by the local size.
 */
kernel void kernel_get_rows_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Byte-addressed bases with the buffer offsets applied.
    global char * src0_base = (global char *)src0 + offset0;
    global char * src1_base = (global char *)src1 + offset1;
    global char * dst_base  = (global char *)dst  + offsetd;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row number to fetch from src0 for this work-group.
    int r = *(global int32_t *)(src1_base + i12*nb12 + i11*nb11 + i10*nb10);

    // The two outer coordinates of src0 follow the outer index coordinates.
    int i02 = i11;
    int i03 = i12;

    // Row addresses do not depend on the loop variable; compute them once.
    global half  * src_row = (global half  *)(src0_base + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *)(dst_base + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition alone bounds ind, so no extra range check is needed.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        dst_row[ind] = src_row[ind];   // implicit half -> float conversion
    }
}
143
/*
 * Gather rows of Q4_0-quantized data and dequantize them to float: for every
 * index stored in src1, decode the selected row of src0 into dst.
 *
 *   src0, offset0   - source buffer and byte offset to its data
 *                     (rows of struct block_q4_0).
 *   src1, offset1   - index buffer and byte offset to its data (32-bit int).
 *   dst,  offsetd   - destination buffer and byte offset to its data (float).
 *   ne00            - number of elements per row; ne00/16 chunks of 16 floats
 *                     are written, so a remainder of ne00 % 16 is not emitted.
 *   nb01..nb03      - byte strides of src0 (row, then the two outer dims).
 *   ne10            - not used by this kernel.
 *   nb10..nb12      - byte strides of src1.
 *   nb1..nb3        - byte strides of dst.
 *
 * Work-group (i10, i11, i12) handles one index entry; the work-items of a
 * group step through the 16-float chunks of the row in dimension 0, each
 * advancing by the local size.
 */
kernel void kernel_get_rows_q4_0(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Byte-addressed bases with the buffer offsets applied.
    global char * src0_base = (global char *)src0 + offset0;
    global char * src1_base = (global char *)src1 + offset1;
    global char * dst_base  = (global char *)dst  + offsetd;

    // Each block is decoded as NL chunks of 16 values.
    const int NL = 2;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row number to fetch from src0 for this work-group.
    int r = *(global int32_t *)(src1_base + i12*nb12 + i11*nb11 + i10*nb10);

    // The two outer coordinates of src0 follow the outer index coordinates.
    int i02 = i11;
    int i03 = i12;

    // Row addresses do not depend on the loop variable; compute them once.
    global struct block_q4_0 * src_row =
        (global struct block_q4_0 *)(src0_base + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *)(dst_base + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition alone bounds ind, so no extra range check is needed.
    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
        float16 temp;

        // Chunk ind lives in block ind/NL; ind%NL picks which half of it.
        dequantize_q4_0_f32(src_row + ind/NL, ind%NL, &temp);

        // vstore16 writes to dst_row + 16*ind and only requires float
        // alignment, unlike a store through a float16 pointer, which would
        // require the address to be aligned to sizeof(float16).
        vstore16(temp, ind, dst_row);
    }
}