#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#define QK_MXFP4 32

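// A block of QK_MXFP4 (32) values occupies 17 bytes (sizeof(block_mxfp4)): one
// e8m0 scale byte followed by 16 bytes of packed 4-bit (e2m1) quants. Byte j
// holds element j in its low nibble and element j+16 in its high nibble. In the
// "flat" layout used by this kernel, the scale bytes and the quant bytes are
// split into the two separate buffers src0_e and src0_q.
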
// Convert four packed fp4 (e2m1) values in one ushort to four fp16 values
// without a lookup table. Result order: (low nibble of byte 0, high nibble of
// byte 0, low nibble of byte 1, high nibble of byte 1).
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
    // Move each 3-bit exponent/mantissa field into the fp16 exponent/mantissa
    // position (bits 11..9).
    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;

    // Rebias the exponent: e2m1 bias is 1, fp16 bias is 15, so add 14 << 10
    // (0x3800) unless the value is zero.
    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;

    // The e2m1 subnormal (exp = 0, mant = 1) encodes 0.5; drop its mantissa
    // bit so that adding the 0x3800 bias yields exactly 0.5 in fp16.
    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;

    // Move each sign bit into fp16 bit 15.
    sign_a.lo = (fp4x4 << 12) & 0x8000;
    sign_a.hi = (fp4x4 << 8) & 0x8000;
    sign_b.lo = (fp4x4 << 4) & 0x8000;
    sign_b.hi = fp4x4 & 0x8000;

    // The three fields never overlap, so addition assembles the fp16 bit pattern.
    fp16_packed_a = sign_a + bias_a + fp16_packed_a;
    fp16_packed_b = sign_b + bias_b + fp16_packed_b;

    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}

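// Decode an e8m0 block scale: the byte is a biased power-of-two exponent, so
// placing it in the fp32 exponent field yields 2^(x - 127). x == 0 would
// otherwise produce 0.0, so it is special-cased to the fp32 subnormal bit
// pattern for 2^-127.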
static inline float e8m0_to_fp32(uchar x) {
    int bits;
    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}

#ifdef INTEL_GPU
#define N_R0_MXFP4 2   // number of rows each subgroup works on
#define N_SG_MXFP4 2   // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_MXFP4 4   // number of rows each subgroup works on
#define N_SG_MXFP4 1   // number of subgroups in a work group
#define N_SIMDWIDTH 64 // subgroup size
#define SRC0Q_IMG      // read the quantized data through a 1D image buffer
#endif

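// Matrix-vector multiplication over MXFP4 data with an indirection through
// src2: each work group reads an index i02 from src2 (e.g. an expert id) and
// uses it to pick the src0 matrix it multiplies. "flat" refers to the split
// block layout described above (separate src0_q / src0_e buffers). Each
// subgroup accumulates N_R0_MXFP4 rows of the output.
//
// The subgroup size is pinned here so that it matches N_SIMDWIDTH, which the
// per-lane work split below assumes; this reuses the REQD_SUBGROUP_SIZE_*
// macros defined at the top of the file.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif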
kernel void kernel_mul_mv_id_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
        __read_only image1d_buffer_t src0_q,
#else
        global uchar * src0_q,
#endif
        global uchar * src0_e,
        global uchar * src1,
        ulong offset1,
        global uchar * src2,
        ulong offset2,
        global uchar * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne11,
        int ne12,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne20,
        int ne21,
        ulong nb21,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    dst = dst + offsetd;

    // Decompose the third grid dimension: iid1 selects a row of the id tensor
    // (src2), idx selects an entry within that row.
    const int iid1 = get_group_id(2) / ne20;
    const int idx = get_group_id(2) % ne20;

    // i02 is the src0 matrix selected through src2 for this work group.
    uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx];

    int i11 = idx % ne11;

    // Number of MXFP4 blocks per row.
    int nb = ne00 / QK_MXFP4;

    uint src0_off = i02*nb02;
    src0_off /= 17; // 17 = sizeof(block_mxfp4); offsets below are in blocks

    src0_e = src0_e + src0_off;

    dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);

    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;

    uint offset_src0 = first_row*nb01;
    offset_src0 /= 17; // 17 = sizeof(block_mxfp4)
#ifdef SRC0Q_IMG
    ulong offset_q = src0_off + offset_src0;
#else
    src0_q = src0_q + src0_off*16; // 16 bytes of packed fp4 values per block
    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
    global uchar * x_e = src0_e + offset_src0;

    // Two lanes cooperate on each block: ix picks the block, it picks which
    // 8-byte half of its quantized data (and the matching src1 values) to use.
    const short ix = get_sub_group_local_id() >> 1;
    const short it = get_sub_group_local_id() & 1;

    float sumf[N_R0_MXFP4] = {0.f};

    src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12;
    global float * y = (global float *) (src1 + r1 * nb11);
    global float * yb = y + ix * QK_MXFP4 + it * 8;

    // Each iteration handles one block per row: dequantize this lane's 16 fp4
    // values, multiply by the matching src1 values, and apply the e8m0 scale.
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {
        global float4 * y4 = (global float4 *)yb;

        #pragma unroll
        for (short row = 0; row < N_R0_MXFP4; row++) {
            uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif

            // Low nibbles pair with y4[0]/y4[1], high nibbles with y4[4]/y4[5]
            // (elements j and j+16 share a byte).
            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
        }

        yb += (N_SIMDWIDTH / 2) * QK_MXFP4;
    }

    global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0;

    // Combine the per-lane partial sums across the subgroup; lane 0 writes the
    // final value for each row this subgroup owns.
    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
        float sum_all = sub_group_reduce_add(sumf[row]);
        if (get_sub_group_local_id() == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}