// Enables the `half` type used by block_q8_0.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Subgroup built-ins (get_sub_group_id, get_sub_group_local_id,
// sub_group_reduce_add): use the Intel extension when the compiler advertises
// it, otherwise fall back to the Khronos one.
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Vendor detection. Exactly one of INTEL_GPU / ADRENO_GPU is defined when the
// matching "required subgroup size" extension is available, together with the
// REQD_SUBGROUP_SIZE_* kernel attributes for that vendor.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// NOTE(review): the macro names imply "half" == 64 lanes and "full" == 128
// lanes on Adreno -- confirm against the cl_qcom_reqd_sub_group_size spec.
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

 21#define QK8_0 32
 22typedef struct {
 23    half d;       // delta
 24    char qs[QK8_0]; // quants
 25} block_q8_0;
 26
 27#define NB_Q8_0 8
 28
 29#ifdef INTEL_GPU
 30#define N_R0_Q8_0 4 // number of rows each subgroup works on
 31#define N_SG_Q8_0 2 // number of subgroups in a work group
 32#define N_SIMDWIDTH 16 // subgroup size
 33#elif defined (ADRENO_GPU)
 34#define N_R0_Q8_0 4
 35#define N_SG_Q8_0 2
 36#define N_SIMDWIDTH 64
 37#endif
 38
 39#ifdef INTEL_GPU
 40REQD_SUBGROUP_SIZE_16
 41#elif defined (ADRENO_GPU)
 42REQD_SUBGROUP_SIZE_64
 43#endif
 44kernel void kernel_mul_mv_id_q8_0_f32(
 45    global char * src0,
 46    ulong         offset0,
 47    global char * src1,
 48    ulong         offset1,
 49    global char * src2,
 50    ulong         offset2,
 51    global char * dst,
 52    ulong         offsetd,
 53    int           ne00,
 54    int           ne01,
 55    ulong         nb01,
 56    ulong         nb02,
 57    int           ne11,
 58    int           ne12,
 59    ulong         nb11,
 60    ulong         nb12,
 61    int           ne20,
 62    int           ne21,
 63    ulong         nb21,
 64    int           ne0,
 65    int           ne1
 66) {
 67    src0 = (global char *)((global char *)src0 + offset0);
 68    src1 = (global char *)((global char *)src1 + offset1);
 69    src2 = (global char *)((global char *)src2 + offset2);
 70    dst  = (global char *)((global char *)dst  + offsetd);
 71
 72    int iid1 = get_group_id(2)/ne20;
 73    int idx  = get_group_id(2)%ne20;
 74
 75    int i02 = ((global int *) (src2 + iid1*nb21))[idx];
 76
 77    int i11_ = idx % ne11;
 78    int i12_ = iid1;
 79
 80    int i1 = idx;
 81    int i2 = i12_;
 82
 83    global char * src0_cur = src0 + i02*nb02;
 84    global char * src1_cur = src1 + i11_*nb11 + i12_*nb12;
 85
 86    global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);
 87
 88    int nb = ne00/QK8_0;
 89
 90    int r0 = get_group_id(0);
 91    int r1 = get_group_id(1);
 92
 93    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;
 94
 95    ulong offset_src1 = r1*nb11;
 96    global float * y  = (global float *) (src1_cur + offset_src1);
 97
 98    // pointers to src0 rows
 99    global block_q8_0 * ax[N_R0_Q8_0];
100    for (int row = 0; row < N_R0_Q8_0; ++row) {
101        ulong offset_src0 = (first_row + row)*nb01;
102        ax[row] = (global block_q8_0 *) ((global char *) src0_cur + offset_src0);
103    }
104
105    float yl[NB_Q8_0];
106    float sumf[N_R0_Q8_0] = { 0.f };
107
108    const short ix = get_sub_group_local_id()/4;
109    const short il = get_sub_group_local_id()%4;
110
111    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;
112
113    // each thread handles NB_Q8_0 quants at a time
114    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {
115        for (short i = 0; i < NB_Q8_0; ++i) {
116            yl[i] = yb[i];
117        }
118
119        for (short row = 0; row < N_R0_Q8_0; row++) {
120            global char * qs = ax[row][ib].qs + il*NB_Q8_0;
121            float sumq = 0.f;
122            for (short iq = 0; iq < NB_Q8_0; ++iq) {
123                sumq += qs[iq] * yl[iq];
124            }
125            sumf[row] += sumq*ax[row][ib].d;
126        }
127
128        yb += N_SIMDWIDTH*NB_Q8_0;
129    }
130
131    global float * dst_f32 = (global float *) dst_cur + (ulong)r1*ne0;
132
133    for (int row = 0; row < N_R0_Q8_0; ++row) {
134        float tot = sub_group_reduce_add(sumf[row]);
135
136        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
137            dst_f32[first_row + row] = tot;
138        }
139    }
140}