summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl162
1 files changed, 162 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl
new file mode 100644
index 0000000..3917aa3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl
@@ -0,0 +1,162 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+#define QK_MXFP4 32
+#define N_SIMDGROUP 2
+#define SIMDGROUP_WIDTH 64
+
+// Decode eight MXFP4 (E2M1) nibbles packed in two ushorts (four nibbles per
+// ushort, lowest nibble first) into eight IEEE half values, returned in
+// nibble order s0..s7.
+//
+// Per nibble: bit 3 is the sign, bits 2..0 are the exponent+mantissa field.
+// The 3-bit magnitude field is shifted into fp16 bits 11..9 (mask 0x0E00);
+// adding the bias 0x3800 (14 << 10) rebases the fp4 exponent (bias 1) onto
+// the fp16 exponent (bias 15), e.g. nibble 0x2 -> 0x0400 + 0x3800 = 0x3C00
+// = 1.0h.  A field of exactly 0x0200 is the fp4 subnormal 0.5 and is zeroed
+// first so the bias alone yields 0x3800 == 0.5h; a field of 0 gets bias 0
+// and decodes to +/-0.  The sign bit is moved to fp16 bit 15 and added in.
+static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
+ ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
+ // Magnitude fields of nibbles 0..3 of fp4x8.s0, placed at fp16 bits 11..9.
+ fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
+ fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
+ fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
+ fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
+
+ // Exponent rebias, applied only to non-zero magnitudes so zero stays +/-0.
+ bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
+ bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
+ bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
+ bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
+
+ // 0x0200 is the fp4 subnormal 0.5: clear the field so the bias alone
+ // produces 0x3800 (0.5h).
+ fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
+ fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
+ fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
+ fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
+
+ // Move each nibble's sign bit (bit 3, 7, 11, 15) to fp16 bit 15.
+ sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
+ sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
+ sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
+ sign_b.hi = fp4x8.s0 & 0x8000;
+
+ // Sign, bias and magnitude occupy disjoint bit ranges, so addition
+ // assembles the final fp16 bit patterns.
+ fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
+ fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
+
+ // Same decode for nibbles 4..7, held in fp4x8.s1.
+ ushort2 fp16_packed_a_1, fp16_packed_b_1;
+ fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
+ fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
+ fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
+ fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
+
+ bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
+ bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
+ bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
+ bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
+
+ fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
+ fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
+ fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
+ fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
+
+ sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
+ sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
+ sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
+ sign_b.hi = fp4x8.s1 & 0x8000;
+
+ fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
+ fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
+
+ // Reinterpret the eight assembled ushort bit patterns as eight halves.
+ return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
+}
+
+// Convert an E8M0 block scale (8-bit biased exponent, bias 127) to float by
+// placing the byte directly into the IEEE-754 single-precision exponent
+// field.  The encoding has no zero: x == 0 denotes 2^-127, produced via the
+// subnormal bit pattern 0x00400000 (= 0.5 * 2^-126).
+static inline float e8m0_to_fp32(uchar x) {
+ uint scale_bits = ((uint) x) << 23;
+ if (x == 0) {
+ scale_bits = 0x00400000;
+ }
+ return as_float((int) scale_bits);
+}
+
+
+// Fused mixture-of-experts GEMM tile: MXFP4 weights x f32 activations.
+// Each work-item accumulates the dot product of one weight row (within one
+// tile of one expert) against one activation row; the workgroup's
+// N_SIMDGROUP subgroups split the K dimension and combine partial sums
+// through local memory at the end.
+//
+//  src0_q    quantized weights, one uint4 (= 32 E2M1 nibbles = one MXFP4
+//            block) per entry; indexing below shows blocks stored with the
+//            row index fastest (stride 1 over rows, stride ne01 over
+//            K-blocks)
+//  src0_e    one E8M0 scale byte per MXFP4 block, same layout as src0_q
+//  src1      f32 activations, fetched as float4 texels
+//  src2      per-token routing records (expert_id, activation row i11,
+//            output row i1, tile_id)
+//  dst       f32 output
+//  offsetd   output offset in bytes (converted to float elements below)
+//  ne00      K dimension; NOTE(review): the loop bound uses integer
+//            division, so ne00 is presumably a multiple of QK_MXFP4 —
+//            confirm at dispatch time
+//  ne01      number of rows per expert
+//  tile_size rows per tile (ne01 is split into tiles across workgroups)
+__attribute__((qcom_reqd_sub_group_size("half")))
+__kernel void kernel_gemm_moe_mxfp4_f32(
+ __global uint4 * src0_q,
+ __global uchar * src0_e,
+ __read_only image1d_buffer_t src1,
+ __global ushort4 * src2,
+ __global float * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int tile_size
+) {
+ uint i01 = get_global_id(0); // row within the tile
+ uint i20 = get_global_id(2); // routing-record index
+ uint sgid = get_local_id(1); // subgroup id within the workgroup
+ uint slid = get_sub_group_local_id();
+
+ ushort4 router = src2[i20];
+ ushort expert_id = router.x;
+ ushort i11 = router.y; // activation row in src1
+ ushort i1 = router.z; // output row in dst
+ ushort tile_id = router.w;
+
+ if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
+ return;
+ }
+
+ // ne00 * ne01 / 32 = number of uint4 MXFP4 blocks per expert.
+ uint expert_offset = expert_id * ne00 * ne01 / 32;
+ uint tile_offset = expert_offset + tile_id * tile_size + i01;
+
+ __private float sum = 0.0f; // each thread calculate partial sum of one output
+
+ // loop along ne00 in block granularity, skip 4 blocks every iter
+ for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
+ // load one block of q
+ uint4 regQ = src0_q[tile_offset + ib00 * ne01];
+ // convert 8 fp4 to fp16
+ half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
+
+ // Each block covers 32 floats = 8 float4 texels starting here.
+ // NOTE(review): even half lanes pair with texels [offset..offset+3]
+ // and odd lanes with [offset+4..offset+7], which presumes the ggml
+ // MXFP4 byte packing (element i in the low nibble, element i+16 in
+ // the high nibble) — confirm against the quantizer.
+ uint offset = i11 * ne00 / 4 + ib00 * 8;
+ float4 shared_y4;
+ shared_y4 = read_imagef(src1, (offset + 0));
+ float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+ shared_y4 = read_imagef(src1, (offset + 4));
+ acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+ fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
+
+ shared_y4 = read_imagef(src1, (offset + 1));
+ acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+ shared_y4 = read_imagef(src1, (offset + 5));
+ acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+ fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
+
+ shared_y4 = read_imagef(src1, (offset + 2));
+ acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+ shared_y4 = read_imagef(src1, (offset + 6));
+ acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+ fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
+
+ shared_y4 = read_imagef(src1, (offset + 3));
+ acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+ shared_y4 = read_imagef(src1, (offset + 7));
+ acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+ // Apply the block's E8M0 scale to this block's partial dot product.
+ uchar regE = src0_e[tile_offset + ib00 * ne01];
+ sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
+ }
+
+ // Cross-subgroup reduction in local memory for N_SIMDGROUP == 2; the
+ // commented-out lines are the disabled 4-subgroup variant.
+ __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
+ if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
+ // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
+ // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+ // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+ // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+ // one output per thread, written by subgroup 0 only
+ if (sgid == 0) {
+ dst = dst + (offsetd >> 2); // offsetd is in bytes; >> 2 gives float index
+ dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
+ }
+
+}