summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl194
1 files changed, 194 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
new file mode 100644
index 0000000..819e519
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
@@ -0,0 +1,194 @@
1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_subgroups
4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5#else
6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7#endif
8
9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
21#define QK4_0 32
22#define QR4_0 2
23#define QK4_1 32
24#define QR4_1 2
25#define QK5_0 32
26#define QR5_0 2
27#define QK5_1 32
28#define QR5_1 2
29#define QK8_0 32
30#define QR8_0 1
31#define QK_K 256
32#define K_QUANTS_PER_ITERATION 2
33
34typedef char int8_t;
35typedef uchar uint8_t;
36typedef short int16_t;
37typedef ushort uint16_t;
38typedef int int32_t;
39typedef uint uint32_t;
40
41//------------------------------------------------------------------------------
42// block_q6_K
43//------------------------------------------------------------------------------
44// 6-bit quantization
45// weight is represented as x = a * q
46// 16 blocks of 16 elements each
47// Effectively 6.5625 bits per weight
48typedef struct {
49 uint8_t ql[QK_K/2]; // quants, lower 4 bits
50 uint8_t qh[QK_K/4]; // quants, upper 2 bits
51 int8_t scales[QK_K/16]; // scales, quantized with 8 bits
52 half d; // super-block scale
53} block_q6_K;
54
55//------------------------------------------------------------------------------
56// kernel_mul_mv_q6_K_f32
57//------------------------------------------------------------------------------
58
59#undef N_DST
60#undef N_SIMDGROUP
61#undef N_SIMDWIDTH
62
63#ifdef INTEL_GPU
64#define N_DST 1 // number of rows each SIMD group works on
65#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
66#define N_SIMDWIDTH 16 // SIMD group size
67#elif defined (ADRENO_GPU)
68#define N_DST 1
69#define N_SIMDGROUP 2
70#define N_SIMDWIDTH 64
71#endif
72
73#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
74
75#ifdef INTEL_GPU
76REQD_SUBGROUP_SIZE_16
77#elif defined (ADRENO_GPU)
78REQD_SUBGROUP_SIZE_64
79#endif
80kernel void kernel_mul_mv_q6_K_f32(
81 global void * src0,
82 ulong offset0,
83 global float * src1,
84 ulong offset1,
85 global float * dst,
86 ulong offsetd,
87 int ne00,
88 int ne01,
89 int ne02,
90 int ne10,
91 int ne12,
92 int ne0,
93 int ne1,
94 int r2,
95 int r3
96) {
97 src0 = (global void*)((global char*)src0 + offset0);
98 src1 = (global float*)((global char*)src1 + offset1);
99 dst = (global float*)((global char*)dst + offsetd);
100
101 uchar kmask1 = 0x03;
102 uchar kmask2 = 0x0C;
103 uchar kmask3 = 0x30;
104 uchar kmask4 = 0xC0;
105
106 int nb = ne00/QK_K;
107
108 int r0 = get_group_id(0);
109 int r1 = get_group_id(1);
110 int im = get_group_id(2);
111
112 int row = N_SIMDGROUP * r0 + get_sub_group_id();
113
114 if (row >= ne01) {
115 return;
116 }
117
118 int i12 = im%ne12;
119 int i13 = im/ne12;
120
121 ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
122
123 global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
124 global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
125
126 float sumf = 0;
127
128 // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a
129 // block. Values in a subblock shares a scale that is quantized with 8 bits;
130 // the entire block shares a single floating point scale.
131 // For work distribution, each thread processes a subblock (16 weights), hence
132 // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
133 // (super) blocks -- this is the block stride.
134 // The 16 threads that process a (super) block are split into 2 portions, each has
135 // 8 threads; each portion works on 8 subblocks.
136 // For subgroup of 16 threads, the entire subgroup works on a single (super) block
137 // before moving to the next (super) block. Thread0 - thread7 work on the
138 // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks.
139 // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
140 // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
141 // works on a total of 16 weight values.
142 int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
143 int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
144 int ip = tid/8; // first or second half of (super) block (0 or 1)
145 int il = tid%8; // each half has 8 parts, one per scale
146 int n = 4; // 4 scales at a time (and 4 sums)
147 int l0 = n*il; // offset into half-block, 0..28
148 int is = 8*ip + l0/16; // 0, 1, 8, 9
149
150 int y_offset = 128*ip + l0;
151 int q_offset_l = 64*ip + l0;
152 int q_offset_h = 32*ip + l0;
153
154 for (int i = ix; i < nb; i += BLOCK_STRIDE) {
155
156 global uint8_t * q1 = x[i].ql + q_offset_l;
157 global uint8_t * q2 = q1 + QK_K/8;
158 global uint8_t * qh = x[i].qh + q_offset_h;
159 global int8_t * sc = x[i].scales + is;
160
161 global float * y = yy + i * QK_K + y_offset;
162
163 float dall = x[i].d;
164
165 float4 sums = {0.f, 0.f, 0.f, 0.f};
166
167 sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
168 sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
169 sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
170 sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
171
172 sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
173 sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
174 sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
175 sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
176
177 sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
178 sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
179 sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
180 sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
181
182 sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
183 sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
184 sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
185 sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
186
187 sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
188 }
189
190 float tot = sub_group_reduce_add(sumf);
191 if (get_sub_group_local_id() == 0) {
192 dst[r1*ne0 + im*ne0*ne1 + row] = tot;
193 }
194}