diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl | 156 |
1 file changed, 156 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl new file mode 100644 index 0000000..b4b1e51 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl | |||
| @@ -0,0 +1,156 @@ | |||
| 1 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||
| 2 | #pragma OPENCL EXTENSION cl_khr_subgroups : enable | ||
| 3 | #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable | ||
| 4 | |||
| 5 | #define QK_MXFP4 32 | ||
| 6 | #define N_SIMDGROUP 4 | ||
| 7 | #define SIMDGROUP_WIDTH 64 | ||
| 8 | |||
// Decode 8 packed MXFP4 values (E2M1: 1 sign, 2 exponent, 1 mantissa bit)
// carried in two ushorts into 8 half-precision floats.
//
// For one 4-bit code s|e1|e0|m, the three magnitude bits are shifted so
// they land on fp16 bits 11..9 (the two low exponent bits plus the top
// mantissa bit) and masked with 0x0E00. Adding the bias 0x3800 then yields
// the fp16 bit pattern of the E2M1 value; the bias is skipped for an
// all-zero magnitude so that +/-0 decodes to +/-0. The shifted pattern
// 0x0200 is the E2M1 subnormal code 0b001 (value 0.5): its lone mantissa
// bit must be dropped, leaving just the 0x3800 bias, which is exactly 0.5
// in fp16. The sign bit is shifted up to fp16 bit 15 and added last.
//
// Lane order of the result: the four nibbles of fp4x8.s0 map to lanes
// (s0, s1, s2, s3) via the a_0/b_0 pairs, and fp4x8.s1 likewise to
// (s4, s5, s6, s7).
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    // Place the 3 magnitude bits of each nibble of fp4x8.s0 at bits 11..9.
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;  // nibble 0 (bits 3..0)
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;  // nibble 1 (bits 7..4)
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;  // nibble 2 (bits 11..8)
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;  // nibble 3 (bits 15..12)

    // Exponent bias 0x3800, applied only to non-zero magnitudes.
    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    // Subnormal code 0.5 (0x0200 after the shift): keep only the bias.
    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    // Move each nibble's sign bit to fp16 bit 15.
    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    // Assemble sign | exponent+bias | mantissa (fields are disjoint, so
    // addition is equivalent to OR here).
    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    // Same decode for the four nibbles of fp4x8.s1.
    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    // Reinterpret the 8 assembled ushorts as 8 halfs.
    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
| 60 | |||
// Decode an E8M0 block scale (8-bit biased exponent, no sign, no mantissa)
// into an fp32 value of 2^(x - 127) by writing x straight into the float's
// exponent field. x == 0 would need exponent field 0, which fp32 treats as
// subnormal, so that case is built from mantissa bit 22 instead:
// 0x00400000 is the subnormal 2^22 * 2^-149 = 2^-127.
static inline float e8m0_to_fp32(uchar x) {
    uint bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}
| 66 | |||
| 67 | |||
// MoE GEMV: one f32 activation row times an MXFP4-quantized expert matrix.
//
//   src0_q   quantized weights; one uint4 (16 B) holds the 32 fp4 codes of
//            one QK_MXFP4 block. Blocks are addressed as
//            [expert][ib00 * ne01 + i01], i.e. block-row-major across the
//            ne01 output rows.
//   src0_e   one E8M0 scale byte per block, same indexing as src0_q.
//   src1     activations, fetched as float4 texels via an image1d buffer.
//   src2     expert id for each token i20.
//   dst      output; one float per (row i01, token i20).
//   offsetd  byte offset into dst (>> 2 converts to float elements).
//   ne00     reduction (K) length; assumed a multiple of QK_MXFP4 and of
//            4 texels per read pattern — TODO confirm against host code.
//   ne11     number of activation rows; i20 is folded modulo ne11.
//
// Work split: N_SIMDGROUP (4) subgroups per work-group, each walking the
// K blocks with stride 4; partial sums are combined through local memory
// and lanes of subgroup 0 write the result. Requires subgroup width
// SIMDGROUP_WIDTH (64); "half" is a Qualcomm Adreno wave-size selector —
// NOTE(review): qcom-specific attribute, confirm the target honors it.
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32(
        __global uint4 * src0_q,
        __global uchar * src0_e,
        __read_only image1d_buffer_t src1,
        __global uint * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    uint i01 = get_global_id(0);        // output row within the expert
    uint i20 = get_global_id(2);        // token index
    uint sgid = get_local_id(1);        // subgroup id within the work-group
    uint slid = get_sub_group_local_id();

    uint i11 = i20 % ne11;              // activation row for this token

    uint expert_id = src2[i20];
    // Blocks per expert = ne00 * ne01 / 32 (32 = QK_MXFP4 values per block).
    uint expert_offset = expert_id * ne00 * ne01 / 32;

    __private float sum = 0.0f; // each thread calculate partial sum of one output

    // loop along ne00 in block granularity, skip 4 blocks every iter
    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {

        // load one block of q
        uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];

        // First float4 texel of this block in the activation row
        // (ne00 / 4 texels per row, 8 texels per 32-value block).
        uint offset = i11 * ne00 / 4 + ib00 * 8;

        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

        // Within a block the decoded nibbles are interleaved: even lanes of
        // fp16x8 pair with texel (offset + k), odd lanes with texel
        // (offset + k + 4) — i.e. activation elements 16 apart.
        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);


        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

        // Apply the block's shared E8M0 scale to the block dot product.
        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory, assumes #subgroups=4
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 outputs per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2);
        dst[i01 + i20 * ne01] = sum;
    }

}
