#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#define QK_MXFP4 32

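// Convert four packed FP4 (E2M1) values to four halfs. A ushort holds two
// bytes of quantized data, i.e. four 4-bit values (sign:1, exponent:2,
// mantissa:1). Each value is rebuilt as an fp16 bit pattern: the three
// magnitude bits are shifted into place, the exponent is rebiased from the
// E2M1 bias (1) to the fp16 bias (15) by adding 0x3800, and zero and the
// E2M1 subnormal (exponent 0, mantissa 1, i.e. 0.5) are special-cased.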
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
    // Shift the three magnitude bits of each 4-bit value into fp16 bit
    // positions 9..11 (exponent LSBs plus mantissa MSB).
    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;

    // Rebias nonzero values: (15 - 1) << 10 = 0x3800.
    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;

    // E2M1 subnormal (0x0200 after the shift) encodes 0.5: clear the shifted
    // bits so the result is just the bias, which is exactly 0.5 in fp16.
    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;

    // Move each sign bit to bit 15.
    sign_a.lo = (fp4x4 << 12) & 0x8000;
    sign_a.hi = (fp4x4 << 8) & 0x8000;
    sign_b.lo = (fp4x4 << 4) & 0x8000;
    sign_b.hi = fp4x4 & 0x8000;

    fp16_packed_a = sign_a + bias_a + fp16_packed_a;
    fp16_packed_b = sign_b + bias_b + fp16_packed_b;

    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}

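// Convert an E8M0 scale (8 exponent bits, bias 127) to fp32 by placing x in
// the float exponent field. x == 0 would need 2^-127, which is below the
// smallest normal float, so it is special-cased as the subnormal bit pattern
// 0x00400000 (mantissa MSB set), which equals 2^-127 exactly.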
static inline float e8m0_to_fp32(uchar x) {
    int bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}

#ifdef INTEL_GPU
#define N_R0_MXFP4 2 // number of rows each subgroup works on
#define N_SG_MXFP4 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_SIMDWIDTH 64
#define SRC0Q_IMG // read the quantized weights through an image1d_buffer_t
#endif

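// Matrix-vector multiply, dst = src0 * src1, where src0 is MXFP4-quantized
// and src1 is fp32. "Flat" layout: the 4-bit quants (src0_q) and the E8M0
// block scales (src0_e) are passed as two separate buffers rather than as
// interleaved 17-byte blocks. Each subgroup computes N_R0_MXFP4 rows, and
// each pair of lanes cooperates on one 32-element block.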
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
    __read_only image1d_buffer_t src0_q,
#else
    global uchar * src0_q,
#endif
    global uchar * src0_e, // E8M0 block scales, one byte per block
    global uchar * src1,
    ulong          offset1,
    global uchar * dst,
    ulong          offsetd,
    int ne00,   // src0 row length in elements
    ulong nb01, // src0 byte strides
    ulong nb02,
    ulong nb03,
    int ne12,   // src1 batch dimension
    ulong nb11, // src1 byte strides
    ulong nb12,
    ulong nb13,
    int ne0,    // dst dimensions
    int ne1,
    int r2,     // src1/src0 broadcast ratios
    int r3
) {
    src1 = src1 + offset1;
    dst = dst + offsetd;

    int nb = ne00 / QK_MXFP4; // number of blocks per row

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;

    uint i12 = im % ne12;
    uint i13 = im / ne12;

    // Byte offset of this row group's first block, then converted to a block
    // index; 17 = sizeof(block_mxfp4) (1 scale byte + 16 quant bytes).
    uint offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    offset_src0 /= 17;
#ifdef SRC0Q_IMG
    ulong offset_q = offset_src0;
#else
    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
    global uchar * x_e = src0_e + offset_src0;

    ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13;
    global float * y = (global float *)(src1 + offset_src1);

    const short ix = get_sub_group_local_id() >> 1; // block index within the subgroup: 0..N_SIMDWIDTH/2-1
    const short it = get_sub_group_local_id() & 1;  // which 8-byte half of the block: 0 or 1

    float sumf[N_R0_MXFP4] = {0.f};

    global float * yb = y + ix * QK_MXFP4 + it * 8;

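    // March over the blocks of the row, N_SIMDWIDTH/2 blocks per iteration.
    // Each lane loads half of its block: 8 quant bytes (16 FP4 values),
    // dotted against the matching 16 floats of y.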
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
        global float4 * y4 = (global float4 *)yb;

        #pragma unroll
        for (short row = 0; row < N_R0_MXFP4; row++) {
            uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif

            // Low nibbles decode to elements i (paired with y4[0]/y4[1]),
            // high nibbles to elements i+16 (paired with y4[4]/y4[5]).
            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            // Apply the per-block E8M0 scale to the partial dot product.
            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
        }

        yb += (N_SIMDWIDTH/2) * QK_MXFP4;
    }

    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;

    // Reduce each row's partial sums across the subgroup; lane 0 writes out.
    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
        float sum_all = sub_group_reduce_add(sumf[row]);
        if (get_sub_group_local_id() == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}