aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl156
1 files changed, 156 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl
new file mode 100644
index 0000000..b4b1e51
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl
@@ -0,0 +1,156 @@
1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2#pragma OPENCL EXTENSION cl_khr_subgroups : enable
3#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
4
5#define QK_MXFP4 32
6#define N_SIMDGROUP 4
7#define SIMDGROUP_WIDTH 64
8
// Decode 8 MXFP4 (E2M1) 4-bit codes, packed into two ushorts, into 8 halves.
//
// Each nibble: bit 3 = sign, bits 2:1 = exponent (bias 1), bit 0 = mantissa.
// Result ordering: for each input ushort the four nibbles decode least-
// significant-first, so the returned half8 is
//   (s0.n0, s0.n1, s0.n2, s0.n3, s1.n0, s1.n1, s1.n2, s1.n3)
// where nK is the K-th nibble of that ushort.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    // Move each nibble's 3 exponent+mantissa bits into fp16 bit positions
    // 11:9 (exponent LSBs plus mantissa MSB); 0x0E00 masks exactly those bits.
    // Shifts 9/5/1/-3 select nibbles 0/1/2/3 respectively.
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;

    // 0x3800 = 14 << 10: rebias the exponent from fp4 (bias 1) to fp16
    // (bias 15) for every nonzero code; the true zero code keeps bias 0 so
    // it decodes to +/-0.
    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    // 0x0200 is the fp4 subnormal (exp == 0, mantissa == 1, value 0.5):
    // clear the mantissa so the rebias alone yields 0x3800 == 0.5h.
    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    // Move each nibble's sign bit (bit 3, 7, 11, 15) into fp16 bit 15.
    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    // Integer add (not OR) so the exponent rebias carries into bit 14 for
    // codes with exp >= 2; sign never overlaps, so the sum is the fp16 pattern.
    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    // Same pipeline for the second ushort (nibbles 4..7).
    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    // Reinterpret the 8 assembled fp16 bit patterns as halves.
    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
60
// Convert an E8M0 scale byte to fp32: the byte is a biased exponent, so the
// result is 2^(x - 127).
static inline float e8m0_to_fp32(uchar x) {
    // Exponent -127 (x == 0) is not representable as an fp32 normal; it is
    // exactly the subnormal 0x00400000 (0.5 * 2^-126), so special-case it.
    if (x == 0) {
        return as_float(0x00400000);
    }
    // Otherwise placing x directly in the fp32 exponent field (bits 30:23)
    // with a zero mantissa produces exactly 2^(x - 127).
    return as_float((uint) x << 23);
}
66
67
// MoE GEMV with MXFP4-quantized weights: for each (token-slot i20, output
// row i01) pair, computes the dot product of one weight row of the expert
// selected by src2[i20] with the activation row i11 of src1.
//
// Work mapping (NOTE(review): inferred from the index math below — confirm
// against the host-side dispatch):
//   global dim 0 -> output row i01
//   local  dim 1 -> N_SIMDGROUP subgroups that split the K dimension (ne00)
//   global dim 2 -> flattened token/expert slot i20
//
// Arguments:
//   src0_q  quantized weights; one uint4 = 16 bytes = one 32-value MXFP4
//           block (QK_MXFP4). Blocks are laid out block-major across rows:
//           block ib00 of row i01 is at expert_offset + ib00 * ne01 + i01.
//   src0_e  one E8M0 scale byte per block, same layout as src0_q.
//   src1    activations as a float4 image buffer; row i11 starts at texel
//           i11 * ne00 / 4 and each block covers 8 consecutive texels.
//   src2    expert index for each slot i20.
//   dst     float output; offsetd is a byte offset (>> 2 gives float elems).
//   ne00    K (features per row); ne01 rows per expert; ne11 src1 rows.
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32(
    __global uint4 * src0_q,
    __global uchar * src0_e,
    __read_only image1d_buffer_t src1,
    __global uint * src2,
    __global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int ne11
) {
    uint i01 = get_global_id(0);
    uint i20 = get_global_id(2);
    uint sgid = get_local_id(1);          // which of the N_SIMDGROUP subgroups
    uint slid = get_sub_group_local_id(); // lane within the subgroup

    uint i11 = i20 % ne11;

    uint expert_id = src2[i20];
    // Offset into src0_q/src0_e in blocks: each expert owns
    // ne00 * ne01 / 32 MXFP4 blocks (32 == QK_MXFP4 values per block).
    uint expert_offset = expert_id * ne00 * ne01 / 32;

    __private float sum = 0.0f; // each thread calculate partial sum of one output

    // loop along ne00 in block granularity, skip 4 blocks every iter
    // (the N_SIMDGROUP subgroups stride the K dimension cooperatively)
    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {

        // load one block of q
        uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];

        // First src1 texel (float4) of this block: texels are float4, so a
        // row holds ne00/4 texels and a 32-value block spans 8 of them.
        uint offset = i11 * ne00 / 4 + ib00 * 8;

        // MXFP4 packs element j in the low nibble and element j+16 in the
        // high nibble of byte j, so each decoded half8 interleaves elements
        // k..k+3 (even lanes s0/s2/s4/s6) with k+16..k+19 (odd lanes
        // s1/s3/s5/s7); the two image reads fetch the matching texels.
        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

        // Apply the block's E8M0 scale to the 32-element partial dot product.
        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory, assumes #subgroups=4
    // Subgroups 1..3 stage their partial sums; subgroup 0 accumulates them
    // after the barrier. Lanes of different subgroups with equal slid hold
    // partial sums of the same output element.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 outputs per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2); // offsetd is in bytes; dst is float*
        dst[i01 + i20 * ne01] = sum;
    }

}