// half-precision scalars (block_q4_K.d/.dmin) require cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12

// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
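// each of the 8 blocks has a 6-bit scale and a 6-bit min:
// (8 + 8) * 6 bits = 96 bits = 12 bytes = K_SCALE_SIZE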
typedef struct {
    half d;                     // super-block scale for quantized scales
    half dmin;                  // super-block scale for quantized mins

    uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uchar qs[QK_K/2];           // 4-bit quants
} block_q4_K;

#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif

#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)

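// mat-vec multiplication of a q4_K-quantized matrix (src0) with an f32
// matrix (src1): each subgroup computes N_DST output rows; every thread
// handles one 32-weight block, so a subgroup covers BLOCK_STRIDE super
// blocks per loop iteration and the partial sums are reduced at the end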
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
        global char * src0,
        int offset0,
        global char * src1,
        int offset1,
        global char * dst,
        int offsetd,
        int ne00,
        int ne01,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst = dst + offsetd;

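    // masks for unpacking the 6-bit scales/mins (read as ushort pairs):
    // kmask1 keeps the low 6 bits of each byte, kmask2 the low 4 bits,
    // kmask3 the top 2 bits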
    ushort kmask1 = 0x3f3f;
    ushort kmask2 = 0x0f0f;
    ushort kmask3 = 0xc0c0;

    int ix = get_sub_group_local_id()/8; // super block index
    int it = get_sub_group_local_id()%8; // block index (inside super block)
    int iq = it/4; // 0 or 1 - first or second half of the super block
    int ir = it%4; // 0...3 - block index in the half super block

    int nb = ne00/QK_K;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);
    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;

    int i12 = im%ne12;
    int i13 = im/ne12;

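    // byte offsets of this workgroup's first src0 row and src1 column;
    // dividing i12/i13 by r2/r3 presumably broadcasts src0 across src1's
    // larger batch dimensions (ggml-style repeat), with r2/r3 the ratios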
    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;

    global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
    global float * y = (global float *) (src1 + offset_src1);

    float yl[16];
    float yh[16];
    float sumf[N_DST] = {0.f};
    float all_sum;

    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

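    // sc8 aliases sc16: the four unpacked 16-bit words are read back as
    // eight 8-bit scale/min values in the accumulation below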
    ushort sc16[4];
    uchar * sc8 = (uchar *)sc16;

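    // each thread strides over super blocks; per super block it loads 8 y
    // values from each of the 4 sub-blocks it touches (offsets 0, 32, 128,
    // 160) and accumulates their sums in sumy for the min correction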
    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+0];
            sumy.s0 += yl[i+0];

            yl[i+8] = y4[i+32];
            sumy.s1 += yl[i+8];

            yh[i+0] = y4[i+128];
            sumy.s2 += yh[i+0];

            yh[i+8] = y4[i+160];
            sumy.s3 += yh[i+8];
        }

        global ushort * sc = (global ushort *)x[ib].scales + iq;
        global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
        global half * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {
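            // unpack the 6-bit scales and mins of the 4 sub-blocks used by
            // this thread into sc16 (read through sc8 as bytes below)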
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            global ushort * q2 = q1 + 32;

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
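            // quants are packed two per byte: within each ushort load,
            // 0x000F/0x0F00 select the low nibbles of the two bytes and
            // 0x00F0/0xF000 the high nibbles; the 1/256 and 1/16 factors
            // below compensate for the un-shifted nibble positions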
            for (int i = 0; i < 8; i += 2) {
                acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
                acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
                acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
                acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
                acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
                acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
                acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
                acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
            }

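            // dh[0] = d (super-block scale for the scales) and
            // dh[1] = dmin (super-block scale for the mins), i.e. the two
            // leading halfs of block_q4_K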
            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);

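            // step to the same block in the next row: nb01 is the row
            // stride in bytes and these are 2-byte (ushort/half) pointers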
            q1 += nb01/2;
            sc += nb01/2;
            dh += nb01/2;
        }

        y4 += BLOCK_STRIDE * QK_K;
    }

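    // reduce each row's partial sums across the subgroup; lane 0 writes
    // the final dot product, guarding against rows past ne01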
    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;

    for (int row = 0; row < N_DST; ++row) {
        all_sum = sub_group_reduce_add(sumf[row]);
        if (first_row + row < ne01) {
            if (get_sub_group_local_id() == 0) {
                dst_f32[first_row + row] = all_sum;
            }
        }
    }
}