#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#define QK8_0 32
typedef struct {
    half d;         // delta
    char qs[QK8_0]; // quants
} block_q8_0;
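
// In this "flat" variant the deltas and quants are assumed to be split into
// two separate buffers (src0_d and src0_q below), one entry per block, rather
// than interleaved; block_q8_0 documents the original interleaved layout.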
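// Number of quants each work-item processes per inner-loop step.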
#define NB_Q8_0 8

#ifdef INTEL_GPU
#define N_R0_Q8_0 4 // number of rows each subgroup works on
#define N_SG_Q8_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_SIMDWIDTH 64
#endif

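// Mat-vec multiplication of flat q8_0 src0 with f32 src1:
//   src0_q - q8_0 quants, QK8_0 bytes per block
//   src0_d - q8_0 deltas, one half per block
//   src1   - f32 input, offset1 applied in bytes
//   dst    - f32 output, offsetd applied in bytes
// neXX/nbXX follow the usual ggml convention: neXX are element counts and
// nbXX are byte strides of dimension XX; r2/r3 are the broadcast ratios of
// src1 over src0 in dims 2 and 3.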
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q8_0_f32_flat(
        global char * src0_q,
        global half * src0_d,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src1 = (global char*)((global char*)src1 + offset1);
    dst = (global char*)((global char*)dst + offsetd);

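    // number of q8_0 blocks per row of src0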
    int nb = ne00/QK8_0;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;

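    // decompose the flattened batch index into dims 2 and 3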
    uint i12 = im%ne12;
    uint i13 = im/ne12;

    ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
    global float * y = (global float *) (src1 + offset_src1);

    // pointers to src0 rows
    uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;

    global char * ax0, * ax1, * ax2, * ax3;
    global half * ad0, * ad1, * ad2, * ad3;
    uint offset_src0;

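    // nb01 is the row stride in bytes of the original interleaved q8_0 tensor,
    // where each block occupies sizeof(block_q8_0) = 34 bytes (a 2-byte half
    // delta followed by QK8_0 quant bytes). Dividing the byte offset by 34
    // converts it into a block index, which then addresses the flat quant
    // buffer (QK8_0 bytes per block) and delta buffer (one half per block).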
    offset_src0 = offset_src0_base + 0*nb01;
    offset_src0 = offset_src0/34;
    ax0 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
    ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));

    offset_src0 = offset_src0_base + 1*nb01;
    offset_src0 = offset_src0/34;
    ax1 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
    ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));

    offset_src0 = offset_src0_base + 2*nb01;
    offset_src0 = offset_src0/34;
    ax2 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
    ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));

    offset_src0 = offset_src0_base + 3*nb01;
    offset_src0 = offset_src0/34;
    ax3 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
    ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));

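    // 4 work-items cooperate on each block: ix selects the block within the
    // subgroup's current batch, il selects the 8-quant chunk within the block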
    const short ix = get_sub_group_local_id()/4;
    const short il = get_sub_group_local_id()%4;

    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;

    float8 yl;
    float8 qv;
    float4 sumf = 0.f;
    float sumq = 0.f;
    global char * qs;

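    // with 4 work-items per block, a subgroup covers N_SIMDWIDTH/4 blocks per iteration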
    // each thread handles NB_Q8_0 quants at a time
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {
        yl = vload8(0, yb);

        qs = ax0 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
        qv = convert_float8(vload8(0, qs));
        sumq = 0;
        sumq += qv.s0*yl.s0;
        sumq += qv.s1*yl.s1;
        sumq += qv.s2*yl.s2;
        sumq += qv.s3*yl.s3;
        sumq += qv.s4*yl.s4;
        sumq += qv.s5*yl.s5;
        sumq += qv.s6*yl.s6;
        sumq += qv.s7*yl.s7;
        sumf.s0 += sumq*ad0[ib];

        qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
        qv = convert_float8(vload8(0, qs));
        sumq = 0;
        sumq += qv.s0*yl.s0;
        sumq += qv.s1*yl.s1;
        sumq += qv.s2*yl.s2;
        sumq += qv.s3*yl.s3;
        sumq += qv.s4*yl.s4;
        sumq += qv.s5*yl.s5;
        sumq += qv.s6*yl.s6;
        sumq += qv.s7*yl.s7;
        sumf.s1 += sumq*ad1[ib];

        qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
        qv = convert_float8(vload8(0, qs));
        sumq = 0;
        sumq += qv.s0*yl.s0;
        sumq += qv.s1*yl.s1;
        sumq += qv.s2*yl.s2;
        sumq += qv.s3*yl.s3;
        sumq += qv.s4*yl.s4;
        sumq += qv.s5*yl.s5;
        sumq += qv.s6*yl.s6;
        sumq += qv.s7*yl.s7;
        sumf.s2 += sumq*ad2[ib];

        qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
        qv = convert_float8(vload8(0, qs));
        sumq = 0;
        sumq += qv.s0*yl.s0;
        sumq += qv.s1*yl.s1;
        sumq += qv.s2*yl.s2;
        sumq += qv.s3*yl.s3;
        sumq += qv.s4*yl.s4;
        sumq += qv.s5*yl.s5;
        sumq += qv.s6*yl.s6;
        sumq += qv.s7*yl.s7;
        sumf.s3 += sumq*ad3[ib];

        yb += N_SIMDWIDTH*NB_Q8_0;
    }

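    // output row base for this batch element (im) and src1 column (r1);
    // dst holds ne0 x ne1 floats per batch element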
    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;

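    // reduce the per-work-item partial sums across the subgroup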
    float4 tot = (float4)(
        sub_group_reduce_add(sumf.s0),
        sub_group_reduce_add(sumf.s1),
        sub_group_reduce_add(sumf.s2),
        sub_group_reduce_add(sumf.s3)
    );

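    // the first lane writes the results, guarding against rows past ne01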
    if (get_sub_group_local_id() == 0) {
        if (first_row + 0 < ne01) {
            dst_f32[first_row + 0] = tot.s0;
        }
        if (first_row + 1 < ne01) {
            dst_f32[first_row + 1] = tot.s1;
        }
        if (first_row + 2 < ne01) {
            dst_f32[first_row + 2] = tot.s2;
        }
        if (first_row + 3 < ne01) {
            dst_f32[first_row + 3] = tot.s3;
        }
    }
}