1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3#ifdef cl_intel_subgroups
  4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
  5#else
  6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  7#endif
  8
  9#ifdef cl_intel_required_subgroup_size
 10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
 11#define INTEL_GPU 1
 12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
 13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
 14#elif defined(cl_qcom_reqd_sub_group_size)
 15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
 16#define ADRENO_GPU 1
 17#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
 18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
 19#endif
 20
 21#define QK_MXFP4 32
 22
// Decode four packed FP4 (E2M1) values into four half-precision floats.
// Nibble i of fp4x4 (i = 0..3, low nibble first) becomes component .s0..s3
// of the result. Each nibble is sign(1) | exp(2) | mantissa(1); representable
// magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
//
// The trick: shift each nibble's 3 magnitude bits into fp16 bit positions
// 11..9 (the two low exponent bits plus the top mantissa bit), then add the
// exponent bias 0x3800 (fp16 exponent field = 14), which maps fp4 exponent
// e in 1..3 to the fp16 value 2^(e-1) * (1 + m/2).
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
    // Magnitude bits of nibbles 0..3 moved to fp16 bits 11..9.
    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;

    // A magnitude of all-zero encodes 0.0, which must not receive the bias.
    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;

    // 0x0200 is the fp4 subnormal (exp=0, mantissa=1), i.e. 0.5. Clearing the
    // magnitude and keeping the 0x3800 bias yields exactly fp16 0.5.
    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;

    // Sign bit of each nibble (bit 3, 7, 11, 15) moved to fp16 bit 15.
    sign_a.lo = (fp4x4 << 12) & 0x8000;
    sign_a.hi = (fp4x4 << 8) & 0x8000;
    sign_b.lo = (fp4x4 << 4) & 0x8000;
    sign_b.hi = fp4x4 & 0x8000;

    // Sign, bias, and magnitude occupy disjoint bit ranges, so addition
    // assembles the final fp16 bit patterns without carries.
    fp16_packed_a = sign_a + bias_a + fp16_packed_a;
    fp16_packed_b = sign_b + bias_b + fp16_packed_b;

    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}
 50
 51static inline float e8m0_to_fp32(uchar x) {
 52    int bits;
 53    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
 54    return as_float(bits);
 55}
 56
 57#ifdef INTEL_GPU
 58#define N_R0_MXFP4 2 // number of rows each subgroup works on
 59#define N_SG_MXFP4 2 // number of subgroups in a work group
 60#define N_SIMDWIDTH 16 // subgroup size
 61#elif defined (ADRENO_GPU)
 62#define N_R0_MXFP4 4
 63#define N_SG_MXFP4 1
 64#define N_SIMDWIDTH 64
 65#define SRC0Q_IMG
 66#endif
 67
// Matrix-vector multiply with expert indirection (mul_mat_id) for
// MXFP4-quantized weights (src0) against f32 activations (src1), using a
// "flat" weight layout: the 16 packed-nibble bytes of each 32-element block
// live in src0_q, and the 1-byte E8M0 block scale lives in src0_e. Byte
// offsets computed from the original interleaved layout are converted to
// block indices by dividing by 17 (16 + 1 bytes per block).
// src2 holds the expert ids selecting which src0 matrix each output row uses.
// NOTE(review): ne*/nb* appear to follow the ggml element-count/byte-stride
// naming convention — confirm against the host-side caller.
kernel void kernel_mul_mv_id_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
    // Adreno path: quantized data is fetched through a buffer image.
    __read_only image1d_buffer_t src0_q,
#else
    global uchar * src0_q,
#endif
    global uchar * src0_e,
    global uchar * src1,
    ulong         offset1,
    global uchar * src2,
    ulong         offset2,
    global uchar * dst,
    ulong         offsetd,
    int           ne00,
    ulong         nb01,
    ulong         nb02,
    ulong         nb03,
    int           ne11,
    int           ne12,
    ulong         nb11,
    ulong         nb12,
    ulong         nb13,
    int           ne20,
    int           ne21,
    ulong         nb21,
    int           ne0,
    int           ne1,
    int           r2,
    int           r3
) {
    dst  = dst  + offsetd;

    // Work-group z index is split into (token, expert-slot) coordinates:
    // iid1 selects the row of the ids tensor, idx the used-expert slot in it.
    const int iid1 = get_group_id(2) / ne20;
    const int idx  = get_group_id(2) % ne20;

    // i02 = expert id for this (token, slot) pair, read from src2.
    uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx];

    int i11 = idx % ne11;

    // Number of 32-element quant blocks per weight row.
    int nb = ne00 / QK_MXFP4;

    // Byte offset of the selected expert's matrix, then converted to a
    // block index (flat layout stores blocks contiguously per component).
    uint src0_off = i02*nb02;
    src0_off /= 17; // 17 = sizeof(block_mxfp4)

    src0_e = src0_e + src0_off;

    dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);

    // Each subgroup handles N_R0_MXFP4 consecutive weight rows.
    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;

    // Row offset in bytes -> block index (same /17 conversion as above).
    uint offset_src0 = first_row*nb01;
    offset_src0 /= 17; // 17 = sizeof(block_mxfp4)
#ifdef SRC0Q_IMG
    ulong offset_q = src0_off + offset_src0;
#else
    // 16 bytes of packed nibbles per block in the flat q buffer.
    src0_q = src0_q + src0_off*16;
    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
    global uchar * x_e = src0_e + offset_src0;

    // Two lanes cooperate on each block: ix = starting block index for this
    // lane, it = which 8-byte half of the block's 16 quant bytes it handles.
    const short ix = get_sub_group_local_id() >> 1;
    const short it = get_sub_group_local_id() & 1;

    float sumf[N_R0_MXFP4] = {0.f};

    src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12;
    global float * y   = (global float *) (src1 + r1 * nb11);
    // Lane's starting activation: block ix, half `it` (8 floats in).
    global float * yb = y + ix * QK_MXFP4 + it * 8;

    // Lanes stride over blocks; N_SIMDWIDTH/2 blocks are consumed per
    // subgroup per iteration (two lanes per block).
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {
        global float4 * y4 = (global float4 *)yb;

        #pragma unroll
        for (short row = 0; row < N_R0_MXFP4; row++) {
            // Per-block E8M0 scale byte for this row/block.
            uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
            // Fetch this lane's 8 quant bytes (two uints) from the image.
            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif

            // Low nibbles hold block elements 0..15, high nibbles elements
            // 16..31, so the decoded (s0,s2) pairs pair with y4[0..1] and the
            // (s1,s3) pairs with y4[4..5] (16 floats further on).
            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            // Apply the block scale once to the block's partial dot product.
            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
        }

        // Advance past the blocks processed by the whole subgroup.
        yb += (N_SIMDWIDTH / 2) * QK_MXFP4;
    }

    global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0;

    // Reduce each row's partial sums across the subgroup; lane 0 writes out.
    // Rows beyond ne0 (tail of the last subgroup) are skipped.
    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
        float sum_all = sub_group_reduce_add(sumf[row]);
        if (get_sub_group_local_id() == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}