#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

//------------------------------------------------------------------------------
// kernel_mul_mv_q6_K_f32_flat
//------------------------------------------------------------------------------
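//
// Mat-vec multiplication with q6_K src0 stored in "flat" layout: the components
// of each 256-element super-block are passed as separate buffers instead of an
// array of structs. Per super-block: 128 bytes of low 4-bit quants (ql),
// 64 bytes of packed high 2-bit quants (qh), 16 signed 8-bit scales, and one
// half-precision super-block scale d.
//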
#define Q6_K_MASK1 0x03 // qh bits 0-1: high bits for the quant group at y offset + 0
#define Q6_K_MASK2 0x0C // qh bits 2-3: high bits for the quant group at y offset +32
#define Q6_K_MASK3 0x30 // qh bits 4-5: high bits for the quant group at y offset +64
#define Q6_K_MASK4 0xC0 // qh bits 6-7: high bits for the quant group at y offset +96

#define QK_K 256

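// Compute the partial dot product of one q6_K super-block with the y values
// handled by this lane:
//   blk_ql/blk_qh/blk_scales/blk_d - flat q6_K component arrays for one row
//   ib - super-block index within the row
//   ip - which 128-element half of the super-block (0 or 1)
//   is - index of the first scale used by this lane
//   l0 - starting element offset within the half (0, 4, ..., 28)
// Each lane covers 4 consecutive elements in each of the four 32-element
// groups of its half, i.e. 16 y values in total.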
inline float block_q_6_K_dot_y_flat(
        global uchar * blk_ql,
        global uchar * blk_qh,
        global char * blk_scales,
        global half * blk_d,
        global float * yy,
        int ib,
        int ip,
        int is,
        int l0
) {
    int y_offset = 128*ip + l0;
    int q_offset_l = 64*ip + l0;
    int q_offset_h = 32*ip + l0;

    global uchar * q1 = blk_ql + ib*128 + q_offset_l;
    global uchar * q2 = q1 + QK_K/8;
    global uchar * qh = blk_qh + ib*64 + q_offset_h;
    global char * sc = blk_scales + ib*16 + is;

    global float * y = yy + ib * QK_K + y_offset;

    float dall = blk_d[ib];

    float sumf = 0;
    float4 sums = {0.f, 0.f, 0.f, 0.f};

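    // unrolled over l = 0..3: 4 consecutive elements from each of the four groups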
    sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);

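    // weight each group's partial sum by its signed 8-bit scale and the
    // super-block scale dall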
    sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);

    return sumf;
}

#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

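// N_DST: output rows computed by each subgroup
// N_SIMDGROUP: subgroups per work-group
// N_SIMDWIDTH: assumed subgroup size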
#ifdef INTEL_GPU
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif

#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of super-blocks a subgroup processes in parallel each loop iteration

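//
// Each work-group contains N_SIMDGROUP subgroups; each subgroup computes
// N_DST consecutive rows of the output. Within a subgroup, 16 lanes cooperate
// on one q6_K super-block, so a subgroup of N_SIMDWIDTH lanes walks a row in
// steps of BLOCK_STRIDE super-blocks.
//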
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q6_K_f32_flat(
        global uchar * src0_ql,
        global uchar * src0_qh,
        global char * src0_s,
        global half * src0_d,
        global float * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne10,
        int ne12,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src1 = (global float*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // number of q6_K super-blocks per src0 row
    int nb = ne00/QK_K;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int i12 = im%ne12;
    int i13 = im/ne12;

    int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;

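    // offset, in super-blocks, of the first row handled by this subgroup;
    // each flat component array is scaled by its per-super-block element
    // count (ql: 128, qh: 64, scales: 16, d: 1)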
    ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
    ulong offset_src0_ql = offset_src0 * 128;
    ulong offset_src0_qh = offset_src0 * 64;
    ulong offset_src0_s = offset_src0 * 16;
    ulong offset_src0_d = offset_src0;

    global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql;
    global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh;
    global char * blk_scales = (global char *) src0_s + offset_src0_s;
    global half * blk_d = (global half *) src0_d + offset_src0_d;
    global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;

    int tid = get_sub_group_local_id()/BLOCK_STRIDE; // 0..15: this lane's slice of a super-block (first BLOCK_STRIDE lanes have tid=0)
    int ix = get_sub_group_local_id()%BLOCK_STRIDE; // which of the BLOCK_STRIDE concurrent super-blocks this lane starts on
    int ip = tid/8; // first or second half of (super) block (0 or 1)
    int il = tid%8; // each half has 8 parts, one per scale
    int n = 4; // each lane handles 4 consecutive elements per group (and keeps 4 sums)
    int l0 = n*il; // offset within each 32-element group, 0..28
    int is = 8*ip + l0/16; // 0, 1, 8, 9

    float4 sumf = 0; // one partial sum for each of the N_DST rows

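    // each lane accumulates its slice of every BLOCK_STRIDE-th super-block,
    // for up to N_DST consecutive rows (guarded against rows beyond ne01)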
    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
        if (first_row + 0 < ne01) {
            sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 1 < ne01) {
            sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 2 < ne01) {
            sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 3 < ne01) {
            sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);
        }
    }

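    // reduce the per-lane partial sums across the subgroup; lane 0 writes the
    // final results for the rows that exist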
    float4 tot = (float4)(
        sub_group_reduce_add(sumf.s0),
        sub_group_reduce_add(sumf.s1),
        sub_group_reduce_add(sumf.s2),
        sub_group_reduce_add(sumf.s3)
    );
    if (get_sub_group_local_id() == 0) {
        if (first_row + 0 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
        }
        if (first_row + 1 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
        }
        if (first_row + 2 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
        }
        if (first_row + 3 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
        }
    }
}