1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  3#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
  4
  5#define QK_MXFP4 32
  6#define N_SIMDGROUP 2
  7#define SIMDGROUP_WIDTH 64
  8
  9static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
 10    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
 11    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
 12    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
 13    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
 14    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
 15
 16    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
 17    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
 18    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
 19    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
 20
 21    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
 22    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
 23    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
 24    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
 25
 26    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
 27    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
 28    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
 29    sign_b.hi = fp4x8.s0 & 0x8000;
 30
 31    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
 32    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
 33
 34    ushort2 fp16_packed_a_1, fp16_packed_b_1;
 35    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
 36    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
 37    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
 38    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
 39
 40    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
 41    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
 42    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
 43    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
 44
 45    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
 46    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
 47    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
 48    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
 49
 50    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
 51    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
 52    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
 53    sign_b.hi = fp4x8.s1 & 0x8000;
 54
 55    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
 56    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
 57
 58    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
 59}
 60
 61static inline float e8m0_to_fp32(uchar x) {
 62    int bits;
 63    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
 64    return as_float(bits);
 65}
 66
 67
 68__attribute__((qcom_reqd_sub_group_size("half")))
 69__kernel void kernel_gemm_moe_mxfp4_f32(
 70    __global uint4 * src0_q,
 71    __global uchar * src0_e,
 72    __read_only image1d_buffer_t src1,
 73    __global ushort4 * src2,
 74    __global float * dst,
 75    ulong         offsetd,
 76    int           ne00,
 77    int           ne01,
 78    int           tile_size
 79) {
 80    uint i01  = get_global_id(0);
 81    uint i20  = get_global_id(2);
 82    uint sgid = get_local_id(1);
 83    uint slid = get_sub_group_local_id();
 84
 85    ushort4 router = src2[i20];
 86    ushort expert_id = router.x;
 87    ushort i11 = router.y;
 88    ushort i1 = router.z;
 89    ushort tile_id = router.w;
 90
 91    if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
 92        return;
 93    }
 94
 95    uint expert_offset = expert_id * ne00 * ne01 / 32;
 96    uint tile_offset = expert_offset + tile_id * tile_size + i01;
 97
 98    __private float sum = 0.0f; // each thread calculate partial sum of one output
 99
100    // loop along ne00 in block granularity, skip 4 blocks every iter
101    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
102        // load one block of q
103        uint4 regQ = src0_q[tile_offset + ib00 * ne01];
104        // convert 8 fp4 to fp16
105        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
106
107        uint offset = i11 * ne00 / 4 + ib00 * 8;
108        float4 shared_y4;
109        shared_y4 = read_imagef(src1, (offset + 0));
110        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
111
112        shared_y4 = read_imagef(src1, (offset + 4));
113        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
114
115
116        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
117
118        shared_y4 = read_imagef(src1, (offset + 1));
119        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
120
121        shared_y4 = read_imagef(src1, (offset + 5));
122        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
123
124
125        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
126
127        shared_y4 = read_imagef(src1, (offset + 2));
128        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
129
130        shared_y4 = read_imagef(src1, (offset + 6));
131        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
132
133
134        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
135
136        shared_y4 = read_imagef(src1, (offset + 3));
137        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
138
139        shared_y4 = read_imagef(src1, (offset + 7));
140        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
141
142        uchar regE = src0_e[tile_offset + ib00 * ne01];
143        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
144    }
145
146    // reduction in local memory, assumes #subgroups=4
147    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
148    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
149    // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
150    // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
151    barrier(CLK_LOCAL_MEM_FENCE);
152    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
153    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
154    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
155
156    // 1 outputs per thread in subgroup 0
157    if (sgid == 0) {
158        dst = dst + (offsetd >> 2);
159        dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
160    }
161
162}