#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
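
// The REQD_SUBGROUP_SIZE_* attributes pin the subgroup width the kernel is
// compiled for, so the sub_group_* reductions below always span exactly
// N_SIMDWIDTH work-items on the targeted hardware.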

//------------------------------------------------------------------------------
// kernel_mul_mv_q6_K_f32_flat
//------------------------------------------------------------------------------
#define Q6_K_MASK1 0x03
#define Q6_K_MASK2 0x0C
#define Q6_K_MASK3 0x30
#define Q6_K_MASK4 0xC0

#define QK_K       256

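// Computes a partial dot product of one Q6_K super-block against the
// corresponding QK_K floats of y. In the flat layout the super-block fields
// live in separate buffers: 128 bytes of low quant bits (ql), 64 bytes of high
// quant bits (qh), 16 signed 8-bit scales, and one half-precision scale d.
// Each call handles 16 quants: 4 consecutive elements (starting at l0) from
// each of the 4 quarters of the half-block selected by ip; is selects the
// scales that apply to those quarters.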
inline float block_q_6_K_dot_y_flat(
    global uchar * blk_ql,
    global uchar * blk_qh,
    global char  * blk_scales,
    global half  * blk_d,
    global float * yy,
    int ib,
    int ip,
    int is,
    int l0
) {
    int y_offset   = 128*ip + l0;
    int q_offset_l =  64*ip + l0;
    int q_offset_h =  32*ip + l0;

    global uchar * q1 = blk_ql     + ib*128 + q_offset_l;
    global uchar * q2 = q1         + QK_K/8;
    global uchar * qh = blk_qh     + ib*64 + q_offset_h;
    global char  * sc = blk_scales + ib*16 + is;

    global float * y = yy + ib * QK_K + y_offset;

    float dall = blk_d[ib];

    float  sumf = 0;
    float4 sums = {0.f, 0.f, 0.f, 0.f};

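    // Reconstruct each 6-bit quant from its low 4 bits (ql) and high 2 bits (qh),
    // recenter by subtracting 32, and accumulate one partial sum per quarter.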
    sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);

    sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);
    sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);
    sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);
    sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);

    sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);

    return sumf;
}

#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

#ifdef INTEL_GPU
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif

#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes

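// Each subgroup computes N_DST consecutive rows of the result; a work-group
// holds N_SIMDGROUP subgroups, so one work-group covers N_SIMDGROUP*N_DST rows.
// Within a subgroup, the lanes are split into BLOCK_STRIDE interleaved block
// streams, and each lane handles a 16-element slice of its current block.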
 99#ifdef INTEL_GPU
100REQD_SUBGROUP_SIZE_16
101#elif defined (ADRENO_GPU)
102REQD_SUBGROUP_SIZE_64
103#endif
kernel void kernel_mul_mv_q6_K_f32_flat(
        global uchar * src0_ql,
        global uchar * src0_qh,
        global char  * src0_s,
        global half  * src0_d,
        global float * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne10,
        int ne12,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src1 = (global float*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    int nb = ne00/QK_K;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int i12 = im%ne12;
    int i13 = im/ne12;

    int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;

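    // offset_src0 is a super-block index; each flat buffer is then advanced by
    // that index times its per-super-block size: 128 bytes of ql, 64 bytes of
    // qh, 16 scales, and one half d.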
    ulong offset_src0    = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
    ulong offset_src0_ql = offset_src0 * 128;
    ulong offset_src0_qh = offset_src0 * 64;
    ulong offset_src0_s  = offset_src0 * 16;
    ulong offset_src0_d  = offset_src0;

    global uchar * blk_ql     = (global uchar *) src0_ql + offset_src0_ql;
    global uchar * blk_qh     = (global uchar *) src0_qh + offset_src0_qh;
    global char  * blk_scales = (global char  *) src0_s  + offset_src0_s;
    global half  * blk_d      = (global half  *) src0_d  + offset_src0_d;
    global float * yy         = (global float *) src1    + r1*ne10 + im*ne00*ne1;

    int tid = get_sub_group_local_id()/BLOCK_STRIDE; // consecutive groups of BLOCK_STRIDE lanes share the same tid
    int ix  = get_sub_group_local_id()%BLOCK_STRIDE; // which interleaved block this lane starts on: 0..BLOCK_STRIDE-1
    int ip  = tid/8;   // first or second half of (super) block (0 or 1)
    int il  = tid%8;   // each half has 8 parts, one per scale
    int n   = 4;       // 4 scales at a time (and 4 sums)
    int l0  = n*il;    // offset into half-block, 0..28
    int is  = 8*ip + l0/16; // 0, 1, 8, 9

    float4 sumf = 0;

    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
        if (first_row + 0 < ne01) {
            sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 1 < ne01) {
            sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 2 < ne01) {
            sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);
        }
        if (first_row + 3 < ne01) {
            sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);
        }
    }

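    // Reduce the per-lane partial sums across the subgroup; lane 0 then writes
    // the (up to) N_DST output rows, skipping rows past ne01.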
    float4 tot = (float4)(
        sub_group_reduce_add(sumf.s0),
        sub_group_reduce_add(sumf.s1),
        sub_group_reduce_add(sumf.s2),
        sub_group_reduce_add(sumf.s3)
    );
    if (get_sub_group_local_id() == 0) {
        if (first_row + 0 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
        }
        if (first_row + 1 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
        }
        if (first_row + 2 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
        }
        if (first_row + 3 < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
        }
    }
}
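
// A sketch of the expected host-side dispatch, inferred from the indexing above
// rather than taken from any particular host file: a local size of
// (N_SIMDGROUP*N_SIMDWIDTH, 1, 1) gives each work-group N_SIMDGROUP subgroups,
// and group counts of roughly (ceil(ne01/(N_SIMDGROUP*N_DST)), ne1, batch) make
// get_group_id(0) select the block of output rows, get_group_id(1) the src1
// column, and get_group_id(2) the flattened batch index.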