about | summary | refs | log | tree | commit | diff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
diff options
context:
space:
mode:
author: Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
committer: Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
commit: b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree: 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
download: llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl')
-rw-r--r-- llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl | 144
1 files changed, 144 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
new file mode 100644
index 0000000..9a4d4b9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
@@ -0,0 +1,144 @@
// Half-precision (fp16) support.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Sub-group built-ins: enable the Intel extension when its feature macro is
// defined, otherwise fall back to the Khronos one.
#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Vendor detection through the "required sub-group size" extension macros.
// Exactly one of INTEL_GPU / ADRENO_GPU is defined when a known vendor
// extension is present; neither is defined otherwise.
#if defined(cl_intel_required_subgroup_size)
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
20
21#define QK_MXFP4 32
22typedef struct {
23 uchar e; // E8M0
24 uchar qs[QK_MXFP4/2];
25} block_mxfp4;
26
27constant static float kvalues_mxfp4_f[16] = {
28 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
29};
30
31static inline float e8m0_to_fp32(uchar x) {
32 int bits;
33
34 if (x == 0) {
35 bits = 0x00400000;
36 } else {
37 bits = (uint) x << 23;
38 }
39
40 return as_float(bits);
41}
42
// Per-vendor launch geometry used by the kernel below.
//   N_R0_MXFP4  - rows of src0 handled by each sub-group
//   N_SG_MXFP4  - sub-groups per work-group
//   N_SIMDWIDTH - sub-group size the kernel is written for
// No values are provided when neither vendor macro is defined.
#if defined(INTEL_GPU)
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_SIMDWIDTH 16
#elif defined(ADRENO_GPU)
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_SIMDWIDTH 64
#endif
52
53#ifdef INTEL_GPU
54REQD_SUBGROUP_SIZE_16
55#elif defined (ADRENO_GPU)
56REQD_SUBGROUP_SIZE_64
57#endif
58kernel void kernel_mul_mv_mxfp4_f32(
59 global char * src0,
60 ulong offset0,
61 global char * src1,
62 ulong offset1,
63 global char * dst,
64 ulong offsetd,
65 int ne00,
66 ulong nb01,
67 ulong nb02,
68 ulong nb03,
69 int ne12,
70 ulong nb11,
71 ulong nb12,
72 ulong nb13,
73 int ne0,
74 int ne1,
75 int r2,
76 int r3,
77 local char * shmem
78) {
79 src0 = (global char*)((global char*)src0 + offset0);
80 src1 = (global char*)((global char*)src1 + offset1);
81 dst = (global char*)((global char*)dst + offsetd);
82
83 local float * shmem_f32 = (local float *) shmem;
84 int nb = ne00/QK_MXFP4;
85
86 int r0 = get_group_id(0);
87 int r1 = get_group_id(1);
88 int im = get_group_id(2);
89
90 int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
91
92 uint i12 = im%ne12;
93 uint i13 = im/ne12;
94
95 ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
96 ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
97
98 global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
99 global float * y = (global float *) (src1 + offset_src1);
100
101 const short ix = get_sub_group_local_id()/2; // 0...15
102 const short it = get_sub_group_local_id()%2; // 0 or 1
103
104 shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
105 barrier(CLK_LOCAL_MEM_FENCE);
106
107 float4 yl[4];
108 float sumf[N_R0_MXFP4] = {0.f};
109
110 global float * yb = y + ix * QK_MXFP4 + it * 8;
111
112 for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
113 global float4 * y4 = (global float4 *)yb;
114 yl[0] = y4[0];
115 yl[1] = y4[4];
116 yl[2] = y4[1];
117 yl[3] = y4[5];
118
119 for (short row = 0; row < N_R0_MXFP4; row++) {
120 global block_mxfp4 * xb = x + row*nb + ib;
121 global uchar * q2 = (global uchar *)(xb->qs + 8*it);
122
123 float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
124 float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
125 float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
126 float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
127
128 acc1 = (acc1 + acc3) + (acc2 + acc4);
129
130 sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
131 }
132
133 yb += (N_SIMDWIDTH/2) * QK_MXFP4;
134 }
135
136 global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
137
138 for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
139 float sum_all = sub_group_reduce_add(sumf[row]);
140 if (get_sub_group_local_id() == 0) {
141 dst_f32[first_row + row] = sum_all;
142 }
143 }
144}