// Pin the kernel's subgroup (SIMD wave) width at compile time using the
// vendor-specific "required subgroup size" extension. Exactly one branch is
// taken per device: Intel GPUs expose 16/32-wide subgroups, Qualcomm Adreno
// exposes "half" (64) / "full" (128) waves. The INTEL_GPU / ADRENO_GPU flags
// gate the per-vendor tuning constants further down in this file.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
 12
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K            256 // weights per super-block
#define K_SCALE_SIZE    12  // bytes of packed 6-bit scales/mins per super-block

// One q4_K super-block: 8 blocks of 32 elements each.
// Each weight is reconstructed as x = a * q + b, where per-block scale a and
// min b are themselves quantized to 6 bits and rescaled by d / dmin.
// NOTE: this layout must match the host-side quantizer byte-for-byte.
typedef struct {
    half d;    // super-block scale for quantized scales
    half dmin; // super-block scale for quantized mins

    uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uchar qs[QK_K/2];           // 4-bit quants, two per byte
} block_q4_K;
 28
// Per-vendor tuning knobs for the mat-vec kernel below. Undefine first in
// case an earlier kernel in this file set different values.
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH

#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif

#undef  BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
 47
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// Matrix-vector multiply: dst = src0 * src1, where src0 holds q4_K-quantized
// rows and src1 is f32. Each subgroup computes N_DST consecutive output rows;
// within the subgroup, groups of 8 lanes cooperate on one super-block
// (256 weights), each lane handling a 32-weight block. Partial sums are
// combined with sub_group_reduce_add and lane 0 writes the result.
kernel void kernel_mul_mv_q4_K_f32(
        global char * src0,   // quantized matrix (rows of block_q4_K)
        int offset0,          // byte offset into src0
        global char * src1,   // f32 vector(s)
        int offset1,          // byte offset into src1
        global char * dst,    // f32 output
        int offsetd,          // byte offset into dst
        int ne00,             // src0 row length in elements (multiple of QK_K)
        int ne01,             // number of src0 rows
        ulong nb01,           // src0 row stride in bytes
        ulong nb02,           // src0 stride for dim 2, bytes
        ulong nb03,           // src0 stride for dim 3, bytes
        int ne12,             // src1 dim-2 extent
        ulong nb11,           // src1 row stride in bytes
        ulong nb12,           // src1 stride for dim 2, bytes
        ulong nb13,           // src1 stride for dim 3, bytes
        int ne0,              // dst dim-0 extent
        int ne1,              // dst dim-1 extent
        int r2,               // batch ratio for dim 2 — presumably src1/src0 broadcast factor; verify against caller
        int r3                // batch ratio for dim 3 (same convention as r2)
) {
    // Apply the sub-buffer offsets passed from the host.
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    // Bit masks used to unpack the 6-bit scales/mins from block_q4_K.scales,
    // two bytes at a time (see sc16[] below).
    ushort kmask1 = 0x3f3f;
    ushort kmask2 = 0x0f0f;
    ushort kmask3 = 0xc0c0;

    int ix = get_sub_group_local_id()/8;  // super block index
    int it = get_sub_group_local_id()%8;  // block index (inside super block)
    int iq = it/4;     // 0 or 1 - first or second half of the super block
    int ir = it%4;     // 0...3 - block index in the half super block

    int nb = ne00/QK_K; // super-blocks per row

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2); // flattened batch index
    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;

    // Split the flattened batch index back into dim-2/dim-3 coordinates.
    int i12 = im%ne12;
    int i13 = im/ne12;

    // Byte offsets to this workgroup's slice of src0/src1.
    // NOTE(review): computed in 32-bit int although the strides are ulong —
    // could overflow for very large tensors; matches the other args' width,
    // confirm the host guarantees offsets fit in int.
    int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    int offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;

    global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
    global float      * y = (global float      *) (src1 + offset_src1);

    float yl[16];           // src1 values for the low 64-element half (split 0..31 / 32..63)
    float yh[16];           // src1 values for the high half (128..159 / 160..191)
    float sumf[N_DST] = {0.f}; // per-row accumulators for this lane
    float all_sum;

    // This lane's starting position inside src1: super-block ix, half iq,
    // 8-element chunk ir.
    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    // Unpacked 6-bit scales/mins for the 4 sub-blocks this lane touches;
    // sc8 aliases the same bytes for per-sub-block (8-bit) access.
    ushort  sc16[4];
    uchar * sc8 = (uchar *)sc16;

    // Walk the row: each iteration handles one super-block per lane-octet,
    // so the subgroup advances BLOCK_STRIDE super-blocks at a time.
    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
        // Load the 4 x 8 src1 values this lane needs and accumulate their
        // sums (used below for the dmin correction term).
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+0];
            sumy.s0 += yl[i+0];

            yl[i+8] = y4[i+32];
            sumy.s1 += yl[i+8];

            yh[i+0] = y4[i+128];
            sumy.s2 += yh[i+0];

            yh[i+8] = y4[i+160];
            sumy.s3 += yh[i+8];
        }

        // Row-0 pointers into this super-block: packed scales, packed 4-bit
        // quants (as ushort, i.e. 4 nibbles each), and the d/dmin pair.
        global ushort * sc = (global ushort *)x[ib].scales + iq;
        global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
        global half     * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {
            // Unpack the 6-bit scales (sc16[0..1]) and 6-bit mins (sc16[2..3])
            // for the two sub-blocks handled here: low 6 bits come from
            // sc[0]/sc[2], the high 2 bits of the mins from their top bits,
            // combined with the low nibbles of sc[4].
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            // q2 points 64 bytes ahead: the quants for the high half (yh).
            global ushort * q2 = q1 + 32;

            // Dot products against raw (unshifted) nibbles. Each ushort holds
            // 4 nibbles; the positional factors 1/256 (high byte) and 1/16
            // (high nibble) applied below fold the shifts into the scale.
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
                acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
                acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
                acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
                acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
                acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
                acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
                acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
            }

            // dh[0] is d, dh[1] is dmin (adjacent halves in block_q4_K).
            float dall = dh[0];
            float dmin = dh[1];
            // x = d*scale*q - dmin*min summed over the 4 sub-blocks: sc8[0..1]
            // and sc8[4..5] are scales, sc8[2..3] and sc8[6..7] are mins.
            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);

            // Advance to the same super-block in the next row; nb01 is in
            // bytes, the pointers are ushort/half, hence the /2.
            q1 += nb01/2;
            sc += nb01/2;
            dh += nb01/2;
        }

        // Next super-block chunk of src1 for this lane.
        y4 += BLOCK_STRIDE * QK_K;
    }

    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;

    // Reduce each row's partial sum across the subgroup; lane 0 stores it,
    // guarding against rows past the end when ne01 % N_DST != 0.
    for (int row = 0; row < N_DST; ++row) {
        all_sum = sub_group_reduce_add(sumf[row]);
        if (first_row + row < ne01) {
            if (get_sub_group_local_id() == 0) {
                dst_f32[first_row + row] = all_sum;
            }
        }
    }
}