// Enable the cl_khr_fp16 extension (no half types are used by the kernels below).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Subgroup built-ins (get_sub_group_id, get_sub_group_local_id,
// sub_group_reduce_add): prefer the Intel extension when the compiler
// advertises it, otherwise fall back to the Khronos one.
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Identify the vendor by its "required subgroup size" extension. Each branch
// defines a vendor flag (INTEL_GPU / ADRENO_GPU) plus attribute macros that pin
// a kernel's subgroup size. If neither extension is present, no flag is defined.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// Qualcomm selects the subgroup size by name ("half" / "full") rather than by number.
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
 20
 21#define QK_MXFP4 32
 22typedef struct {
 23    uchar e; // E8M0
 24    uchar qs[QK_MXFP4/2];
 25} block_mxfp4;
 26
 27constant static float kvalues_mxfp4_f[16] = {
 28    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
 29};
 30
 31static inline float e8m0_to_fp32(uchar x) {
 32    int bits;
 33
 34    if (x == 0) {
 35        bits = 0x00400000;
 36    } else {
 37        bits = (uint) x << 23;
 38    }
 39
 40    return as_float(bits);
 41}
 42
 43#ifdef INTEL_GPU
 44#define N_R0_MXFP4 2 // number of rows each subgroup works on
 45#define N_SG_MXFP4 2 // number of subgroups in a work group
 46#define N_SIMDWIDTH 16 // subgroup size
 47#elif defined (ADRENO_GPU)
 48#define N_R0_MXFP4 2
 49#define N_SG_MXFP4 2
 50#define N_SIMDWIDTH 64
 51#endif
 52
 53inline void mul_mv_mxfp4_f32(
 54    global char * src0,
 55    global char * src1,
 56    global char * dst,
 57    int ne00,
 58    ulong nb01,
 59    ulong nb02,
 60    ulong nb03,
 61    int ne12,
 62    ulong nb11,
 63    ulong nb12,
 64    ulong nb13,
 65    int ne0,
 66    int ne1,
 67    int r2,
 68    int r3,
 69    local  char * shmem
 70) {
 71    local float * shmem_f32 = (local float *) shmem;
 72    int nb = ne00/QK_MXFP4;
 73
 74    int r0 = get_group_id(0);
 75    int r1 = get_group_id(1);
 76    int im = 0;
 77
 78    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
 79
 80    uint i12 = im%ne12;
 81    uint i13 = im/ne12;
 82
 83    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
 84    ulong offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 85
 86    global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
 87    global float       * y = (global float       *) (src1 + offset_src1);
 88
 89    const short ix = get_sub_group_local_id()/2;  // 0...15
 90    const short it = get_sub_group_local_id()%2;  // 0 or 1
 91
 92    shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
 93    barrier(CLK_LOCAL_MEM_FENCE);
 94
 95    float4 yl[4];
 96    float sumf[N_R0_MXFP4] = {0.f};
 97
 98    global float * yb = y + ix * QK_MXFP4 + it * 8;
 99
100    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
101        global float4 * y4 = (global float4 *)yb;
102        yl[0] = y4[0];
103        yl[1] = y4[4];
104        yl[2] = y4[1];
105        yl[3] = y4[5];
106
107        for (short row = 0; row < N_R0_MXFP4; row++) {
108            global block_mxfp4 * xb = x + row*nb + ib;
109            global uchar       * q2 = (global uchar *)(xb->qs + 8*it);
110
111            float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
112            float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
113            float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
114            float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);
115
116            acc1 = (acc1 + acc3) + (acc2 + acc4);
117
118            sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
119        }
120
121        yb += (N_SIMDWIDTH/2) * QK_MXFP4;
122    }
123
124    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
125
126    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
127        float sum_all = sub_group_reduce_add(sumf[row]);
128        if (get_sub_group_local_id() == 0) {
129            dst_f32[first_row + row] = sum_all;
130        }
131    }
132}
133
// Indirect (id-selected) MXFP4 x f32 mat-vec.
//
// get_group_id(2) is split into a row iid1 of src2 and a column idx within that
// row (ne20 columns). The int stored at src2[iid1][idx] picks which src0 matrix
// (index i02 along dim 2, stride nb02) is multiplied by the src1 vector at
// (i11 = idx % ne11, i12 = iid1). The result goes to dst column (i1 = idx, i2 = iid1).
// NOTE(review): presumably iid1 is the token and idx the selected-expert slot
// (MoE routing). Confirm against the host-side dispatch.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_id_mxfp4_f32(
    global char * src0,
    ulong         offset0,
    global char * src1,
    ulong         offset1,
    global char * src2,
    ulong         offset2,
    global char * dst,
    ulong         offsetd,
    int           ne00,
    ulong         nb01,
    ulong         nb02,
    ulong         nb03,
    int           ne11,
    int           ne12,
    ulong         nb11,
    ulong         nb12,
    ulong         nb13,
    int           ne20,
    int           ne21,
    ulong         nb21,
    int           ne0,
    int           ne1,
    int           r2,
    int           r3,
    local  char * shmem
) {
    src0 += offset0;
    src1 += offset1;
    src2 += offset2;
    dst  += offsetd;

    const int iid1 = get_group_id(2)/ne20;
    const int idx  = get_group_id(2)%ne20;

    // Index of the src0 matrix selected for this (iid1, idx) pair.
    const int i02 = ((global int *) (src2 + iid1*nb21))[idx];

    const int i11 = idx % ne11; // wraps idx when src1 has fewer than ne20 rows (e.g. ne11 == 1)
    const int i12 = iid1;

    const int i1 = idx;
    const int i2 = i12;

    global char * src0_cur = src0 + i02*nb02;
    global char * src1_cur = src1 + i11*nb11 + i12*nb12;

    // Compute the element offset in 64-bit: i2*ne1*ne0 can exceed INT_MAX for
    // large outputs, which would overflow the previous 32-bit int expression.
    global char * dst_cur = dst + ((ulong)i1*ne0 + (ulong)i2*ne1*ne0)*sizeof(float);

    mul_mv_mxfp4_f32(src0_cur, src1_cur, dst_cur,
        ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shmem);
}