1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3#ifdef cl_intel_subgroups
  4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
  5#else
  6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  7#endif
  8
  9#ifdef cl_intel_required_subgroup_size
 10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
 11#define INTEL_GPU 1
 12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
 13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
 14#elif defined(cl_qcom_reqd_sub_group_size)
 15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
 16#define ADRENO_GPU 1
 17#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
 18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
 19#endif
 20
 21#define QK8_0 32
// One q8_0 quantization block: 32 signed 8-bit quants sharing a single
// half-precision scale. Layout must stay exactly 34 bytes (2 + 32) — the
// flat mat-vec kernels below divide byte offsets by sizeof(block_q8_0)
// to recover block indices.
typedef struct {
    half d;       // delta (per-block scale factor)
    char qs[QK8_0]; // quants: 32 signed int8 values, dequantized as qs[i] * d
} block_q8_0;
 26
 27#define NB_Q8_0 8
 28
 29#ifdef INTEL_GPU
 30#define N_R0_Q8_0 4 // number of rows each subgroup works on
 31#define N_SG_Q8_0 2 // number of subgroups in a work group
 32#define N_SIMDWIDTH 16 // subgroup size
 33#elif defined (ADRENO_GPU)
 34#define N_R0_Q8_0 4
 35#define N_SG_Q8_0 2
 36#define N_SIMDWIDTH 64
 37#endif
 38
 39#ifdef INTEL_GPU
 40REQD_SUBGROUP_SIZE_16
 41#elif defined (ADRENO_GPU)
 42REQD_SUBGROUP_SIZE_64
 43#endif
 44kernel void kernel_mul_mv_q8_0_f32_flat(
 45    global char * src0_q,
 46    global half * src0_d,
 47    global char * src1,
 48    ulong         offset1,
 49    global char * dst,
 50    ulong         offsetd,
 51    int           ne00,
 52    int           ne01,
 53    ulong         nb01,
 54    ulong         nb02,
 55    ulong         nb03,
 56    int           ne12,
 57    ulong         nb11,
 58    ulong         nb12,
 59    ulong         nb13,
 60    int           ne0,
 61    int           ne1,
 62    int           r2,
 63    int           r3
 64) {
 65    src1 = (global char*)((global char*)src1 + offset1);
 66    dst  = (global char*)((global char*)dst  + offsetd);
 67
 68    int nb = ne00/QK8_0;
 69
 70    int r0 = get_group_id(0);
 71    int r1 = get_group_id(1);
 72    int im = get_group_id(2);
 73
 74    int first_row = (r0*N_SG_Q8_0 + get_sub_group_id()) * N_R0_Q8_0;
 75
 76    uint i12 = im%ne12;
 77    uint i13 = im/ne12;
 78
 79    ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
 80    global float * y  = (global float *) (src1 + offset_src1);
 81
 82    // pointers to src0 rows
 83    uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
 84
 85    global char * ax0, * ax1, * ax2, * ax3;
 86    global half * ad0, * ad1, * ad2, * ad3;
 87    uint offset_src0;
 88
 89    offset_src0 = offset_src0_base + 0*nb01;
 90    offset_src0 = offset_src0/34;
 91    ax0 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
 92    ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
 93
 94    offset_src0 = offset_src0_base + 1*nb01;
 95    offset_src0 = offset_src0/34;
 96    ax1 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
 97    ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
 98
 99    offset_src0 = offset_src0_base + 2*nb01;
100    offset_src0 = offset_src0/34;
101    ax2 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
102    ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
103
104    offset_src0 = offset_src0_base + 3*nb01;
105    offset_src0 = offset_src0/34;
106    ax3 = (global char *) ((global char *) src0_q + offset_src0*sizeof(char)*QK8_0);
107    ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
108
109    const short ix = get_sub_group_local_id()/4;
110    const short il = get_sub_group_local_id()%4;
111
112    global float * yb = y + ix*QK8_0 + il*NB_Q8_0;
113
114    float8 yl;
115    float8 qv;
116    float4 sumf = 0.f;
117    float  sumq = 0.f;
118    global char * qs;
119
120    // each thread handles NB_Q8_0 quants at a time
121    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) {
122        yl = vload8(0, yb);
123
124        qs = ax0 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
125        qv = convert_float8(vload8(0, qs));
126        sumq = 0;
127        sumq += qv.s0*yl.s0;
128        sumq += qv.s1*yl.s1;
129        sumq += qv.s2*yl.s2;
130        sumq += qv.s3*yl.s3;
131        sumq += qv.s4*yl.s4;
132        sumq += qv.s5*yl.s5;
133        sumq += qv.s6*yl.s6;
134        sumq += qv.s7*yl.s7;
135        sumf.s0 += sumq*ad0[ib];
136
137        qs = ax1 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
138        qv = convert_float8(vload8(0, qs));
139        sumq = 0;
140        sumq += qv.s0*yl.s0;
141        sumq += qv.s1*yl.s1;
142        sumq += qv.s2*yl.s2;
143        sumq += qv.s3*yl.s3;
144        sumq += qv.s4*yl.s4;
145        sumq += qv.s5*yl.s5;
146        sumq += qv.s6*yl.s6;
147        sumq += qv.s7*yl.s7;
148        sumf.s1 += sumq*ad1[ib];
149
150        qs = ax2 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
151        qv = convert_float8(vload8(0, qs));
152        sumq = 0;
153        sumq += qv.s0*yl.s0;
154        sumq += qv.s1*yl.s1;
155        sumq += qv.s2*yl.s2;
156        sumq += qv.s3*yl.s3;
157        sumq += qv.s4*yl.s4;
158        sumq += qv.s5*yl.s5;
159        sumq += qv.s6*yl.s6;
160        sumq += qv.s7*yl.s7;
161        sumf.s2 += sumq*ad2[ib];
162
163        qs = ax3 + ib*sizeof(char)*QK8_0 + il*NB_Q8_0;
164        qv = convert_float8(vload8(0, qs));
165        sumq = 0;
166        sumq += qv.s0*yl.s0;
167        sumq += qv.s1*yl.s1;
168        sumq += qv.s2*yl.s2;
169        sumq += qv.s3*yl.s3;
170        sumq += qv.s4*yl.s4;
171        sumq += qv.s5*yl.s5;
172        sumq += qv.s6*yl.s6;
173        sumq += qv.s7*yl.s7;
174        sumf.s3 += sumq*ad3[ib];
175
176        yb += N_SIMDWIDTH*NB_Q8_0;
177    }
178
179    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
180
181    float4 tot = (float4)(
182        sub_group_reduce_add(sumf.s0),
183        sub_group_reduce_add(sumf.s1),
184        sub_group_reduce_add(sumf.s2),
185        sub_group_reduce_add(sumf.s3)
186    );
187
188    if (get_sub_group_local_id() == 0) {
189        if (first_row + 0 < ne01) {
190            dst_f32[first_row + 0] = tot.s0;
191        }
192        if (first_row + 1 < ne01) {
193            dst_f32[first_row + 1] = tot.s1;
194        }
195        if (first_row + 2 < ne01) {
196            dst_f32[first_row + 2] = tot.s2;
197        }
198        if (first_row + 3 < ne01) {
199            dst_f32[first_row + 3] = tot.s3;
200        }
201    }
202}