#pragma OPENCL EXTENSION cl_khr_fp16 : enable

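// Prefer the Intel subgroup extension where available; otherwise fall back to
// the Khronos subgroup extension.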
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

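// One MXFP4 block holds QK_MXFP4 = 32 FP4 (E2M1) values packed two per byte,
// plus a single E8M0 scale byte: 17 bytes per block in total.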
#define QK_MXFP4 32

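// Decode four FP4 (E2M1) values, one per nibble of fp4x4, into a half4.
// The three magnitude bits of each nibble are shifted into the fp16
// exponent/mantissa field, then 0x3800 (14 << 10) is added to rebias the
// exponent from FP4 (bias 1) to FP16 (bias 15). The all-zero magnitude keeps
// a zero bias so it decodes to 0.0h, and the FP4 subnormal pattern 0b001
// (0x0200 after shifting) has its mantissa bit cleared so it decodes to 0.5h.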
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
    ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
    fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
    fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
    fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
    fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
    bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
    bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
    bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;

    fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
    fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
    fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
    fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;

    sign_a.lo = (fp4x4 << 12) & 0x8000;
    sign_a.hi = (fp4x4 << 8) & 0x8000;
    sign_b.lo = (fp4x4 << 4) & 0x8000;
    sign_b.hi = fp4x4 & 0x8000;

    fp16_packed_a = sign_a + bias_a + fp16_packed_a;
    fp16_packed_b = sign_b + bias_b + fp16_packed_b;

    return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}

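// Convert an E8M0 scale (a bare biased exponent) to fp32 by placing it in the
// float exponent field, giving 2^(x - 127). x == 0 would leave the exponent
// field all zeros, so it is mapped to the subnormal bit pattern 0x00400000,
// which is exactly 2^-127.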
static inline float e8m0_to_fp32(uchar x) {
    uint bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}

#ifdef INTEL_GPU
#define N_R0_MXFP4 2 // number of rows each subgroup works on
#define N_SG_MXFP4 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_SIMDWIDTH 64
#define SRC0Q_IMG
#endif

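// Matrix-vector multiply of MXFP4-quantized src0 against fp32 src1, with the
// quantized data laid out "flat": quant bytes (src0_q) and E8M0 scales
// (src0_e) live in separate buffers. Each subgroup computes N_R0_MXFP4 rows
// of the output; within a subgroup, pairs of lanes cooperate on one block at
// a time.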
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
        __read_only image1d_buffer_t src0_q,
#else
        global uchar * src0_q,
#endif
        global uchar * src0_e,
        global uchar * src1,
        ulong offset1,
        global uchar * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    src1 = src1 + offset1;
    dst = dst + offsetd;

    int nb = ne00 / QK_MXFP4;

    int r0 = get_group_id(0);
    int r1 = get_group_id(1);
    int im = get_group_id(2);

    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;

    uint i12 = im % ne12;
    uint i13 = im / ne12;

    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    // 17 = sizeof(block_mxfp4)
    offset_src0 /= 17;
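    // offset_src0 is now a block index, which addresses both the flat quant
    // array (16 bytes per block) and the scale array (1 byte per block).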
#ifdef SRC0Q_IMG
    ulong offset_q = offset_src0;
#else
    global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
    global uchar * x_e = src0_e + offset_src0;

    ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13;
    global float * y = (global float *)(src1 + offset_src1);

    const short ix = get_sub_group_local_id() >> 1; // block lane: 0..N_SIMDWIDTH/2 - 1
    const short it = get_sub_group_local_id() & 1;  // 0 or 1
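    // Two lanes share each block: ix selects the block and it selects which
    // 8-byte half of the block's 16 quant bytes (and the matching y values)
    // this lane handles.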

    float sumf[N_R0_MXFP4] = {0.f};

    global float * yb = y + ix * QK_MXFP4 + it * 8;

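    // Stride by N_SIMDWIDTH/2 blocks per iteration, since the subgroup's
    // lanes cover N_SIMDWIDTH/2 blocks at a time.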
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
        global float4 * y4 = (global float4 *)yb;

        #pragma unroll
        for (short row = 0; row < N_R0_MXFP4; row++) {
            uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
            ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
            ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif

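            // Each quant byte holds two values: the low nibble is block value
            // i, the high nibble is value i + 16. After decoding, lanes s0/s2
            // of each half4 are consecutive low-nibble values (paired with
            // y4[0] and y4[1]) while lanes s1/s3 are the +16 values (paired
            // with y4[4] and y4[5]).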
            half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
            half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
            float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
            fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
            acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
            acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);

            sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
        }

        yb += (N_SIMDWIDTH/2) * QK_MXFP4;
    }

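    // Reduce the per-lane partial sums across the subgroup; lane 0 writes
    // each completed row, guarded against rows past the end of the output.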
    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;

    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
        float sum_all = sub_group_reduce_add(sumf[row]);
        if (get_sub_group_local_id() == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}