1//
2// MIT license
3// Copyright (C) 2024 Intel Corporation
4// SPDX-License-Identifier: MIT
5//
6
7//
8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9// See https://llvm.org/LICENSE.txt for license information.
10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11//
12
13#include "mmq.hpp"
14#include "vecdotq.hpp"
15
//
// Function-pointer types that parameterize the generic MMQ (mul-mat-quant)
// kernel: one tile allocator, one tile loader and one per-element dot
// product per quantization format.
//

// Binds caller-provided local-memory buffers to the generic
// x_ql / x_dm / x_qh / x_sc tile pointers.
typedef void (*allocate_tiles_sycl_t)(
    int** x_ql,
    sycl::half2** x_dm,
    int** x_qh,
    int** x_sc);
// Copies one column slice (index k) of quantized blocks from vx into the
// local-memory tile arrays; need-check clamping uses i_max.
typedef void (*load_tiles_sycl_t)(
    const void* __restrict__ vx,
    int* __restrict__ x_ql,
    sycl::half2* __restrict__ x_dm,
    int* __restrict__ x_qh,
    int* __restrict__ x_sc,
    const int& i_offset,
    const int& i_max,
    const int& k,
    const int& blocks_per_row);
// Computes one partial dot product between tile row i of X and tile
// column j of Y at position k.
typedef float (*vec_dot_q_mul_mat_sycl_t)(
    const int* __restrict__ x_ql,
    const sycl::half2* __restrict__ x_dm,
    const int* __restrict__ x_qh,
    const int* __restrict__ x_sc,
    const int* __restrict__ y_qs,
    // NOTE(review): named y_ms here but y_ds at every definition below;
    // parameter names in a typedef have no binding effect.
    const sycl::half2* __restrict__ y_ms,
    const int& i,
    const int& j,
    const int& k);
41
42
43template <int mmq_y>
44static __dpct_inline__ void
45allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
46 int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
47 (void)x_qh; (void)x_sc;
48
49 *x_ql = tile_x_qs_q4_0;
50 *x_dm = (sycl::half2 *)tile_x_d_q4_0;
51}
52
// Load one column slice (index k) of a q4_0 tile into local memory.
//   vx             : source rows, an array of block_q4_0
//   x_ql           : receives packed 4-bit quants (one int = 8 nibbles)
//   x_dm           : reused as a float array holding per-block scales d
//   i_offset       : row offset of this warp inside the tile (< nwarps)
//   i_max          : last valid source row (clamped to when need_check)
//   k              : per-thread column index, 0 <= k < WARP_SIZE
//   blocks_per_row : number of q4_0 blocks per matrix row
// x_qh and x_sc are unused for this quant type.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI4_0;  // block index within the row slice
    const int kqsx = k % QI4_0; // int index within that block's quants

    const block_q4_0 * bx0 = (const block_q4_0 *) vx;

    // q4_0 scales are float; alias the half2 tile buffer accordingly.
    float * x_dmf = (float *) x_dm;

    // First pass: each thread stores one packed-quants int per covered row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp so we never read past the matrix
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;

        // Row stride WARP_SIZE + 1: padded by one int, presumably to avoid
        // local-memory bank conflicts.
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: store one per-block scale per thread.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}
102
// One q4_0 x q8_1 partial dot product between tile row i and tile column j
// at position k, using the tiles laid out by load_tiles_q4_0.
static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    // Remap k into the q8_1-ordered y tile.
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const float * x_dmf = (const float *) x_dm; // q4_0 scales stored as float

    int u[2*VDR_Q4_0_Q8_1_MMQ];

    // Gather the q8_1 values that pair with the low and high nibbles.
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
125
126template <int mmq_y>
127static __dpct_inline__ void
128allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
129 int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
130 (void)x_qh; (void)x_sc;
131
132 *x_ql = tile_x_qs_q4_1;
133 *x_dm = tile_x_dm_q4_1;
134}
135
136
// Load one column slice (index k) of a q4_1 tile into local memory.
// Same scheme as load_tiles_q4_0, except that q4_1 blocks carry a half2
// {d, m} pair (stored directly into x_dm) and the quants are read with the
// aligned accessor. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI4_1;  // block index within the row slice
    const int kqsx = k % QI4_1; // int index within that block's quants

    const block_q4_1 * bx0 = (const block_q4_1 *) vx;

    // First pass: packed 4-bit quants, one int per thread per covered row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: per-block {d, m} half2 pairs.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
    }
}
184
// One q4_1 x q8_1 partial dot product between tile row i and tile column j
// at position k; mirrors vec_dot_q4_0_q8_1_mul_mat but passes the half2
// {d, m} pair through unchanged.
static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    // Remap k into the q8_1-ordered y tile.
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));

    int u[2*VDR_Q4_1_Q8_1_MMQ];

    // Gather the q8_1 values that pair with the low and high nibbles.
#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
206
207template <int mmq_y>
208static __dpct_inline__ void
209allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
210 int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
211 (void)x_qh; (void)x_sc;
212
213 *x_ql = tile_x_ql_q5_0;
214 *x_dm = (sycl::half2 *)tile_x_d_q5_0;
215}
216
// Load one column slice (index k) of a q5_0 tile into local memory.
// The 4-bit low quants and the separately stored 5th bits are merged here
// into sign-extended 8-bit values (two ints per k: low and high nibbles),
// so the dot product can run on the q8_0 path. Scales are float, aliased
// into x_dm. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI5_0;  // block index within the row slice
    const int kqsx = k % QI5_0; // int index within that block's quants

    const block_q5_0 * bx0 = (const block_q5_0 *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;

        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));

        // Scatter the four 5th bits for the low nibbles into bit 4 of each
        // byte (comments give source-bit -> destination-bit positions).
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = dpct::vectorized_binary<sycl::char4>(
            qs0, 0x10101010, dpct::sub_sat()); // subtract 16

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;

        // Same for the high nibbles, using hmask bits 16..19.
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        qs1 = dpct::vectorized_binary<sycl::char4>(
            qs1, 0x10101010, dpct::sub_sat()); // subtract 16

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = k % blocks_per_tile_x_row;
    // q5_0 scales are float; alias the half2 tile buffer accordingly.
    float * x_dmf = (float *) x_dm;

    // Second pass: store one per-block scale per thread.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}
286
// One q5_0 x q8_1 partial dot product. Because load_tiles_q5_0 already
// expanded the quants to signed 8-bit values, this delegates to the q8_0
// implementation with float scales on both sides.
static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    // Remap k into the q8_1-ordered y tile.
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
    const float * x_dmf = (const float *) x_dm; // q5_0 scales stored as float
    const float * y_df = (const float *) y_ds;  // only d of the q8_1 pair is used

    int u[2*VDR_Q5_0_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
    }

    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
310
311template <int mmq_y>
312static __dpct_inline__ void
313allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
314 int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
315 (void)x_qh; (void)x_sc;
316
317 *x_ql = tile_x_ql_q5_1;
318 *x_dm = tile_x_dm_q5_1;
319}
320
// Load one column slice (index k) of a q5_1 tile into local memory.
// Like load_tiles_q5_0 this merges the separately stored 5th bits into
// 8-bit values (two ints per k), but q5_1 values stay unsigned (no -16
// bias) and the per-block half2 {d, m} pair goes straight into x_dm.
// x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI5_1;  // block index within the row slice
    const int kqsx = k % QI5_1; // int index within that block's quants

    const block_q5_1 * bx0 = (const block_q5_1 *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;

        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));

        // Scatter the 5th bits for the low nibbles into bit 4 of each byte
        // (comments give source-bit -> destination-bit positions).
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;

        // Same for the high nibbles, using hmask bits 16..19.
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: per-block {d, m} half2 pairs.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}
385
386static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
387 const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
388 const int *__restrict__ x_qh, const int *__restrict__ x_sc,
389 const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
390 const int &i, const int &j, const int &k) {
391 (void)x_qh; (void)x_sc;
392
393 const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
394 const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
395
396 int u[2*VDR_Q5_1_Q8_1_MMQ];
397
398#pragma unroll
399 for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
400 u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
401 u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
402 }
403
404 return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
405 (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
406}
407
408template <int mmq_y>
409static __dpct_inline__ void
410allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
411 int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
412 (void)x_qh; (void)x_sc;
413
414 *x_ql = tile_x_qs_q8_0;
415 *x_dm = (sycl::half2 *)tile_x_d_q8_0;
416}
417
// Load one column slice (index k) of a q8_0 tile into local memory.
// Quants are already 8-bit signed, so they are copied as-is; per-block
// float scales go into the float-aliased x_dm. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI8_0;  // block index within the row slice
    const int kqsx = k % QI8_0; // int index within that block's quants
    // q8_0 scales are float; alias the half2 tile buffer accordingly.
    float * x_dmf = (float *) x_dm;

    const block_q8_0 * bx0 = (const block_q8_0 *) vx;

    // First pass: packed 8-bit quants, one int per thread per covered row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: store one per-block scale per thread.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
    }
}
466
// One q8_0 x q8_1 partial dot product between tile row i and tile column j
// at position k. Both operands' quants are contiguous 8-bit values, so no
// index remapping is needed.
static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    const float * x_dmf = (const float *) x_dm; // q8_0 scales stored as float
    const float * y_df = (const float *) y_ds;  // only d of the q8_1 pair is used

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}
481
482template <int mmq_y>
483static __dpct_inline__ void
484allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
485 int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
486 int *tile_x_sc_q2_K) {
487 (void)x_qh;
488
489 *x_ql = tile_x_ql_q2_K;
490 *x_dm = tile_x_dm_q2_K;
491 *x_sc = tile_x_sc_q2_K;
492}
493
// Load one column slice (index k) of a q2_K tile into local memory:
// packed 2-bit quants into x_ql, per-superblock half2 {d, dmin} into x_dm
// and the 8-bit scale/min bytes into x_sc. x_qh is unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI2_K;  // superblock index within the row slice
    const int kqsx = k % QI2_K; // int index within that superblock's quants

    const block_q2_K * bx0 = (const block_q2_K *) vx;

    // First pass: packed 2-bit quants, one int per thread per covered row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: per-superblock {d, dmin}. The "% mmq_y" wrap keeps the
    // computed row inside the tile when nwarps * QI2_K exceeds mmq_y.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
    }

    // Third pass: scale/min bytes, four ints of scales per superblock.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);

        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
    }
}
554
#define VDR_Q2_K_Q8_1_MMQ  2
// contiguous u/y values
// q2_K x q8_1 dot product core for the MMQ path.
//   v      : 2-bit quants already masked to [0, 3] per byte
//   u      : contiguous q8_1 values
//   scales : q2_K scale bytes — low nibble is the scale, high nibble the min
//   dm2    : superblock {d, dmin} pair
//   d8     : q8_1 block scale
// Returns d8 * (d * sum(sc*v*u) - dmin * sum(m*u)).
static __dpct_inline__ float
vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const uint8_t *__restrict__ scales,
                           const sycl::half2 &dm2, const float &d8) {

    int sumi_d = 0;
    int sumi_m = 0;

#pragma unroll
    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
        int sumi_d_sc = 0;

        const int sc = scales[i0 / (QI8_1/2)];

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
            sumi_m = dpct::dp4a(m, u[i],
                                sumi_m); // multiply sum of q8_1 values with m
        }

        // Apply the 4-bit scale to this group's dot product.
        sumi_d += sumi_d_sc * (sc & 0xF);
    }

    const sycl::float2 dm2f =
        dm2.convert<float, sycl::rounding_mode::automatic>();

    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
}
591
// One q2_K x q8_1 partial dot product: unpacks the 2-bit values for this
// (i, j, k) position from the tile and forwards them with the matching
// scales to vec_dot_q2_K_q8_1_impl_mmq.
static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;

    const int kbx = k / QI2_K;         // superblock index
    const int ky = (k % QI2_K) * QR2_K; // expanded (de-quantized) position
    const float * y_df = (const float *) y_ds; // only d of the q8_1 pair is used

    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];

    // Source int within the packed tile, and the 2-bit field to extract.
    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));

#pragma unroll
    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; // isolate 2-bit quants
    }

    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;

    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}
618
619template <int mmq_y>
620static __dpct_inline__ void
621allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
622 int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
623 int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
624
625 *x_ql = tile_x_ql_q3_K;
626 *x_dm = tile_x_dm_q3_K;
627 *x_qh = tile_x_qh_q3_K;
628 *x_sc = tile_x_sc_q3_K;
629}
630
// Load one column slice (index k) of a q3_K tile into local memory:
// packed 2-bit low quants into x_ql, per-superblock float d into the
// float-aliased x_dm, inverted high-bit masks into x_qh and repacked
// signed 6-bit scales into x_sc.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI3_K;  // superblock index within the row slice
    const int kqsx = k % QI3_K; // int index within that superblock's quants

    const block_q3_K * bx0 = (const block_q3_K *) vx;

    // First pass: packed low 2-bit quants.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = k % blocks_per_tile_x_row;
    // q3_K carries a single float scale; alias the half2 buffer as float.
    float * x_dmf = (float *) x_dm;

    // Second pass: per-superblock d ("% mmq_y" keeps the row in the tile).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }

    // Third pass: high-bit masks.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);

        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
    }

    // Fourth pass: repack the 6-bit scales (4 low bits + 2 high bits per
    // value) into bytes and re-center them by subtracting 32.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);

        const int ksc = k % (QI3_K/4);

        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        const int sc = dpct::vectorized_binary<sycl::char4>(
            sc_low | sc_high, 0x20202020, dpct::sub_sat());

        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
    }
}
718
#define VDR_Q3_K_Q8_1_MMQ  2
// contiguous u/y values
// q3_K x q8_1 dot product core for the MMQ path.
//   v      : quants with the high-bit correction already applied
//   u      : contiguous q8_1 values
//   scales : signed, re-centered q3_K scales (one per QI8_1/2 group)
//   d3, d8 : block scales of the q3_K and q8_1 operands
// Returns d3 * d8 * sum(sc * dot(v, u)).
static __dpct_inline__ float
vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ scales, const float &d3,
                           const float &d8) {

    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0;

        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1/2)]; // apply this group's scale
    }

    return d3*d8 * sumi;
}
741
// One q3_K x q8_1 partial dot product: reconstructs the signed 3-bit
// values (low 2 bits from x_ql, subtracting the inverted high bit kept in
// x_qh) and forwards them to vec_dot_q3_K_q8_1_impl_mmq.
static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {

    const int kbx = k / QI3_K;          // superblock index
    const int ky = (k % QI3_K) * QR3_K; // expanded (de-quantized) position
    const float * x_dmf = (const float *) x_dm; // q3_K d stored as float
    const float * y_df = (const float *) y_ds;  // only d of the q8_1 pair is used

    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
        const int shift = 2 * ((ky % 32) / 8);
        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; // low 2 bits

        // High bit: the mask was stored inverted, so a set bit subtracts 4.
        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
        const int vlh = (vh << 2) & 0x04040404;

        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
    }

    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}
772
773template <int mmq_y>
774static __dpct_inline__ void
775allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
776 int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
777 int *tile_x_sc_q4_K) {
778 (void)x_qh;
779
780 *x_ql = tile_x_ql_q4_K;
781 *x_dm = tile_x_dm_q4_K;
782 *x_sc = tile_x_sc_q4_K;
783}
784
// Load one column slice (index k) of a q4_K tile into local memory:
// packed 4-bit quants into x_ql, per-superblock half2 {d, dmin} into
// x_dm and the 6-bit scales/mins repacked byte-wise into x_sc.
// x_qh is unused.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256

    const block_q4_K * bx0 = (const block_q4_K *) vx;

    // First pass: packed 4-bit quants.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // Second pass: per-superblock {d, dmin} ("% mmq_y" keeps the row in
    // the tile).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }

    // Third pass: repack the 6-bit scales/mins into whole bytes.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
857
858
#define VDR_Q4_K_Q8_1_MMQ  8

// contiguous u/y values
// q4_K x q8_1 dot product core for the MMQ path.
//   v    : packed 4-bit quants (nibble i selected by the shift below)
//   u    : contiguous q8_1 values
//   sc/m : repacked scales and mins (see load_tiles_q4_K third pass)
//   dm4  : superblock {d, dmin} pair
//   ds8  : per-group q8_1 {d, s} pairs
// Returns d * sum(sc*v*u) - dmin * sum(m * s8).
static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
        }

        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();

        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
    }

    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();

    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}
892
893
// One q4_K x q8_1 partial dot product: selects the scale/min bytes for
// this k position (mins follow the scales at offset 8, see the repacking
// in load_tiles_q4_K) and delegates to vec_dot_q4_K_q8_1_impl_mmq.
static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;

    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);

    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
}
907
908template <int mmq_y>
909static __dpct_inline__ void
910allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
911 int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
912 int *tile_x_sc_q5_K) {
913 (void)x_qh;
914
915 *x_ql = tile_x_ql_q5_K;
916 *x_dm = tile_x_dm_q5_K;
917 *x_sc = tile_x_sc_q5_K;
918}
919
// Load one column slice (index k) of a q5_K tile into local memory.
// The 4-bit low quants and the 5th bits are merged here into 5-bit values
// stored in two ints per k (low/high nibble groups), so x_qh is unused at
// dot-product time. Per-superblock {d, dmin} goes into x_dm and the
// repacked 6-bit scales/mins into x_sc.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    // Optimizer hints: the indices are warp-local and in a known range.
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256

    const block_q5_K * bx0 = (const block_q5_K *) vx;

    // First pass: merge low nibbles with their 5th bits and store two
    // expanded ints per k.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp to the last valid source row
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;

        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F; // low nibbles
        const int ql1 = (ql >> 4) & 0x0F0F0F0F; // high nibbles

        // Lift the matching qh bits into bit 4 of each byte.
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // Destination positions for the two expanded ints.
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // Second pass: per-superblock {d, dmin}.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

        // NOTE(review): unlike load_tiles_q4_K there is no #else branch,
        // so x_dm is left unwritten when QK_K != 256 — confirm builds with
        // other QK_K values are not expected here.
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }

    // Third pass: repack the 6-bit scales/mins into whole bytes.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
1003
1004#define VDR_Q5_K_Q8_1_MMQ 8
1005
// contiguous u/y values
// Dot product of one q5_K sub-block against contiguous q8_1 values (MMQ path).
//   v   : packed quant values; load_tiles_q5_K already merged the 5th bit in
//   u   : matching q8_1 quant values, laid out contiguously
//   sc  : sub-block scales; m points at the sub-block mins (sc + 8 at the
//         call site in vec_dot_q5_K_q8_1_mul_mat)
//   dm4 : (d, dmin) pair of the q5_K block — the "4" in the name is inherited
//         from the q4_K variant this code mirrors
//   ds8 : (d, sum) pairs of the q8_1 blocks
static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {

    float sumf_d = 0.0f; // scale-weighted dot products
    float sumf_m = 0.0f; // min-weighted q8_1 block sums (subtracted at the end)

#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
                                sumi_d); // SIMD dot product
        }

        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();

        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q5_K min val
    }

    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();

    // d * (scale-weighted dots) - dmin * (min-weighted sums)
    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}
1037
1038static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
1039 const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
1040 const int *__restrict__ x_qh, const int *__restrict__ x_sc,
1041 const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
1042 const int &i, const int &j, const int &k) {
1043 (void)x_qh;
1044
1045 const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
1046
1047 const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
1048 const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
1049 return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
1050 x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
1051}
1052
1053template <int mmq_y>
1054static __dpct_inline__ void
1055allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
1056 int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
1057 (void)x_qh;
1058
1059 *x_ql = tile_x_ql;
1060 *x_dm = tile_x_dm;
1061 *x_sc = tile_x_sc;
1062}
1063
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    // Stages one tile of q6_K data into local memory for mul_mat_q:
    //   x_ql : fully assembled, de-biased quants (low 4 bits + high 2 bits)
    //   x_dm : per-block d scales, stored as raw floats
    //   x_sc : int8 scales
    // The high bits are folded directly into x_ql, so no qh tile is needed.
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256

    const block_q6_K * bx0 = (const block_q6_K *) vx;

    // Quants: combine the low 4 bits (ql) with the high 2 bits (qh), then
    // subtract the q6_K bias of 32 from every byte (per-byte saturating sub).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            // clamp the row index so partial tiles never read out of bounds
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;

        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        // destination columns in the tile for the low- and high-nibble halves
        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);

        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
                                                 dpct::sub_sat());
        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
                                                 dpct::sub_sat());
    }

    constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
    // x_dm is reinterpreted: q6_K stores a single float d per block
    float * x_dmf = (float *) x_dm;

    // Per-block d scales.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // int8 scales (4 packed per int).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}
1142
1143#define VDR_Q6_K_Q8_1_MMQ 8
1144
// contiguous u/y values
// Dot product of one q6_K sub-block against contiguous q8_1 values (MMQ path).
//   v  : packed q6_K quants, already de-biased by 32 in load_tiles_q6_K
//   u  : matching q8_1 quant values
//   sc : int8 scales — two q6_K scales are covered per q8_1 scale
//   d6 : the d scale of the q6_K block
//   d8 : per-block float scales of the q8_1 values
static __dpct_inline__ float
vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ sc, const float &d6,
                           const float *__restrict__ d8) {

    float sumf_d = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

        // .x() accumulates elements under scale sc[i0/2], .y() those under
        // sc[i0/2 + 1]; the offsets (+0/+1 vs +4/+5) interleave the two groups.
#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
                                    sumi_d.x()); // SIMD dot product
            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
                                    sumi_d.x()); // SIMD dot product

            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
                                    sumi_d.y()); // SIMD dot product
            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
                                    sumi_d.y()); // SIMD dot product
        }

        sumf_d += d8[i0 / 4] *
                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
    }

    return d6 * sumf_d;
}
1176
1177static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
1178 const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
1179 const int *__restrict__ x_qh, const int *__restrict__ x_sc,
1180 const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
1181 const int &i, const int &j, const int &k) {
1182 (void)x_qh;
1183
1184 const float * x_dmf = (const float *) x_dm;
1185 const float * y_df = (const float *) y_ds;
1186
1187 const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
1188
1189 const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
1190 const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
1191 return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
1192}
1193
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
          vec_dot_q_mul_mat_sycl_t vec_dot>
/*
DPCT1110:8: The total declared local variable size in device function mul_mat_q
exceeds 128 bytes and may cause high register pressure. Consult with your
hardware vendor to find the total register size available and adjust the code,
or use smaller sub-group size to avoid high register pressure.
*/
// Generic tiled kernel body: multiplies a quantized matrix x (blocks of type
// block_q_t) with a q8_1-quantized matrix y and writes float results to dst.
// Each work-group computes one mmq_y x mmq_x tile of dst. `load_tiles` stages
// x data into the caller-provided local-memory buffers (tile_x_*); the loops
// below stage the corresponding y data (tile_y_*); `vec_dot` accumulates the
// per-element products.
static __dpct_inline__ void
mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
          float *__restrict__ dst, const int ncols_x, const int nrows_x,
          const int ncols_y, const int nrows_y, const int nrows_dst,
          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
          sycl::half2 *tile_y_ds) {

    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    // this work-group's tile origin in dst / x / y
    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
    const int & col_y_0 = col_dst_0;

    // per-work-item accumulators for the dst tile
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    // march along the shared dimension, one strip of blocks per iteration
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        // stage the x tile; i_max = nrows_x - row_x_0 - 1 lets load_tiles
        // clamp row indices for partial tiles
        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
                   blocks_per_row_x);

        // stage the y tile: qr passes cover qr*WARP_SIZE q8_1 values
#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
            const int kbxd = kqs / QI8_1;

#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = dpct::min(
                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
                    ncols_y - 1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
                                    kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(
                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
            }

            // stage the y scales (d/sum pairs, or just d as float)
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids =
                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
                    mmq_x;
                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const sycl::half2 *dsi_src =
                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
                       ir * (WARP_SIZE / QI8_1) + kby]
                         .ds;
                sycl::half2 *dsi_dst =
                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = (*dsi_src)[0];
                }
            }

            /*
            DPCT1118:9: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
                            item_ct1.get_local_id(1) + j, k);
                    }
                }
            }

            /*
            DPCT1118:10: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();
        }
    }

    // write back the accumulated tile, skipping out-of-range rows/columns
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);

        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
1336
1337#define MMQ_X_Q4_0_RDNA2 64
1338#define MMQ_Y_Q4_0_RDNA2 128
1339#define NWARPS_Q4_0_RDNA2 8
1340#define MMQ_X_Q4_0_RDNA1 64
1341#define MMQ_Y_Q4_0_RDNA1 64
1342#define NWARPS_Q4_0_RDNA1 8
1343#if defined(SYCL_USE_XMX)
1344#define MMQ_X_Q4_0_AMPERE 4
1345#define MMQ_Y_Q4_0_AMPERE 32
1346#define NWARPS_Q4_0_AMPERE 4
1347#else
1348#define MMQ_X_Q4_0_AMPERE 64
1349#define MMQ_Y_Q4_0_AMPERE 128
1350#define NWARPS_Q4_0_AMPERE 4
1351#endif
1352#define MMQ_X_Q4_0_PASCAL 64
1353#define MMQ_Y_Q4_0_PASCAL 64
1354#define NWARPS_Q4_0_PASCAL 8
1355
1356template <bool need_check> static void
1357 mul_mat_q4_0(
1358 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1359 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1360 const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
1361 int *tile_y_qs, sycl::half2 *tile_y_ds) {
1362 int * tile_x_ql = nullptr;
1363 sycl::half2 *tile_x_dm = nullptr;
1364 int * tile_x_qh = nullptr;
1365 int * tile_x_sc = nullptr;
1366
1367//sycl_todo: change according to hardware
1368
1369 const int mmq_x = MMQ_X_Q4_0_AMPERE;
1370 const int mmq_y = MMQ_Y_Q4_0_AMPERE;
1371 const int nwarps = NWARPS_Q4_0_AMPERE;
1372 allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1373 tile_x_qs_q4_0, tile_x_d_q4_0);
1374 mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
1375 load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
1376 vec_dot_q4_0_q8_1_mul_mat>(
1377 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1378 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1379}
1380
1381#define MMQ_X_Q4_1_RDNA2 64
1382#define MMQ_Y_Q4_1_RDNA2 128
1383#define NWARPS_Q4_1_RDNA2 8
1384#define MMQ_X_Q4_1_RDNA1 64
1385#define MMQ_Y_Q4_1_RDNA1 64
1386#define NWARPS_Q4_1_RDNA1 8
1387#if defined(SYCL_USE_XMX)
1388#define MMQ_X_Q4_1_AMPERE 4
1389#define MMQ_Y_Q4_1_AMPERE 32
1390#define NWARPS_Q4_1_AMPERE 4
1391#else
1392#define MMQ_X_Q4_1_AMPERE 64
1393#define MMQ_Y_Q4_1_AMPERE 128
1394#define NWARPS_Q4_1_AMPERE 4
1395#endif
1396#define MMQ_X_Q4_1_PASCAL 64
1397#define MMQ_Y_Q4_1_PASCAL 64
1398#define NWARPS_Q4_1_PASCAL 8
1399
1400template <bool need_check> static void
1401 mul_mat_q4_1(
1402 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1403 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1404 const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
1405 sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1406 int * tile_x_ql = nullptr;
1407 sycl::half2 *tile_x_dm = nullptr;
1408 int * tile_x_qh = nullptr;
1409 int * tile_x_sc = nullptr;
1410
1411//sycl_todo: change according to hardware
1412 const int mmq_x = MMQ_X_Q4_1_AMPERE;
1413 const int mmq_y = MMQ_Y_Q4_1_AMPERE;
1414 const int nwarps = NWARPS_Q4_1_AMPERE;
1415 allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1416 tile_x_qs_q4_1, tile_x_dm_q4_1);
1417 mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
1418 load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
1419 vec_dot_q4_1_q8_1_mul_mat>(
1420 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1421 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1422}
1423
1424#define MMQ_X_Q5_0_RDNA2 64
1425#define MMQ_Y_Q5_0_RDNA2 128
1426#define NWARPS_Q5_0_RDNA2 8
1427#define MMQ_X_Q5_0_RDNA1 64
1428#define MMQ_Y_Q5_0_RDNA1 64
1429#define NWARPS_Q5_0_RDNA1 8
1430#if defined(SYCL_USE_XMX)
1431#define MMQ_X_Q5_0_AMPERE 4
1432#define MMQ_Y_Q5_0_AMPERE 32
1433#define NWARPS_Q5_0_AMPERE 4
1434#else
1435#define MMQ_X_Q5_0_AMPERE 128
1436#define MMQ_Y_Q5_0_AMPERE 64
1437#define NWARPS_Q5_0_AMPERE 4
1438#endif
1439#define MMQ_X_Q5_0_PASCAL 64
1440#define MMQ_Y_Q5_0_PASCAL 64
1441#define NWARPS_Q5_0_PASCAL 8
1442
1443template <bool need_check> static void
1444 mul_mat_q5_0(
1445 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1446 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1447 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
1448 int *tile_y_qs, sycl::half2 *tile_y_ds) {
1449 int * tile_x_ql = nullptr;
1450 sycl::half2 *tile_x_dm = nullptr;
1451 int * tile_x_qh = nullptr;
1452 int * tile_x_sc = nullptr;
1453
1454//sycl_todo: change according to hardware
1455 const int mmq_x = MMQ_X_Q5_0_AMPERE;
1456 const int mmq_y = MMQ_Y_Q5_0_AMPERE;
1457 const int nwarps = NWARPS_Q5_0_AMPERE;
1458 allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1459 tile_x_ql_q5_0, tile_x_d_q5_0);
1460 mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
1461 load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
1462 vec_dot_q5_0_q8_1_mul_mat>(
1463 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1464 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1465}
1466
1467#define MMQ_X_Q5_1_RDNA2 64
1468#define MMQ_Y_Q5_1_RDNA2 128
1469#define NWARPS_Q5_1_RDNA2 8
1470#define MMQ_X_Q5_1_RDNA1 64
1471#define MMQ_Y_Q5_1_RDNA1 64
1472#define NWARPS_Q5_1_RDNA1 8
1473#if defined(SYCL_USE_XMX)
1474#define MMQ_X_Q5_1_AMPERE 4
1475#define MMQ_Y_Q5_1_AMPERE 32
1476#define NWARPS_Q5_1_AMPERE 4
1477#else
1478#define MMQ_X_Q5_1_AMPERE 128
1479#define MMQ_Y_Q5_1_AMPERE 64
1480#define NWARPS_Q5_1_AMPERE 4
1481#endif
1482#define MMQ_X_Q5_1_PASCAL 64
1483#define MMQ_Y_Q5_1_PASCAL 64
1484#define NWARPS_Q5_1_PASCAL 8
1485
1486template <bool need_check> static void
1487mul_mat_q5_1(
1488 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1489 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1490 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
1491 sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1492 int * tile_x_ql = nullptr;
1493 sycl::half2 *tile_x_dm = nullptr;
1494 int * tile_x_qh = nullptr;
1495 int * tile_x_sc = nullptr;
1496
1497//sycl_todo: change according to hardware
1498 const int mmq_x = MMQ_X_Q5_1_AMPERE;
1499 const int mmq_y = MMQ_Y_Q5_1_AMPERE;
1500 const int nwarps = NWARPS_Q5_1_AMPERE;
1501 allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1502 tile_x_ql_q5_1, tile_x_dm_q5_1);
1503 mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
1504 load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
1505 vec_dot_q5_1_q8_1_mul_mat>(
1506 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1507 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1508}
1509
1510#define MMQ_X_Q8_0_RDNA2 64
1511#define MMQ_Y_Q8_0_RDNA2 128
1512#define NWARPS_Q8_0_RDNA2 8
1513#define MMQ_X_Q8_0_RDNA1 64
1514#define MMQ_Y_Q8_0_RDNA1 64
1515#define NWARPS_Q8_0_RDNA1 8
1516#if defined(SYCL_USE_XMX)
1517#define MMQ_X_Q8_0_AMPERE 4
1518#define MMQ_Y_Q8_0_AMPERE 32
1519#define NWARPS_Q8_0_AMPERE 4
1520#else
1521#define MMQ_X_Q8_0_AMPERE 128
1522#define MMQ_Y_Q8_0_AMPERE 64
1523#define NWARPS_Q8_0_AMPERE 4
1524#endif
1525#define MMQ_X_Q8_0_PASCAL 64
1526#define MMQ_Y_Q8_0_PASCAL 64
1527#define NWARPS_Q8_0_PASCAL 8
1528
1529template <bool need_check> static void
1530 mul_mat_q8_0(
1531 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1532 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1533 const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
1534 int *tile_y_qs, sycl::half2 *tile_y_ds) {
1535 int * tile_x_ql = nullptr;
1536 sycl::half2 *tile_x_dm = nullptr;
1537 int * tile_x_qh = nullptr;
1538 int * tile_x_sc = nullptr;
1539
1540//sycl_todo: change according to hardware
1541 const int mmq_x = MMQ_X_Q8_0_AMPERE;
1542 const int mmq_y = MMQ_Y_Q8_0_AMPERE;
1543 const int nwarps = NWARPS_Q8_0_AMPERE;
1544 allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1545 tile_x_qs_q8_0, tile_x_d_q8_0);
1546 mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
1547 load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
1548 vec_dot_q8_0_q8_1_mul_mat>(
1549 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1550 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1551}
1552
1553#define MMQ_X_Q2_K_RDNA2 64
1554#define MMQ_Y_Q2_K_RDNA2 128
1555#define NWARPS_Q2_K_RDNA2 8
1556#define MMQ_X_Q2_K_RDNA1 128
1557#define MMQ_Y_Q2_K_RDNA1 32
1558#define NWARPS_Q2_K_RDNA1 8
1559#if defined(SYCL_USE_XMX)
1560#define MMQ_X_Q2_K_AMPERE 4
1561#define MMQ_Y_Q2_K_AMPERE 32
1562#define NWARPS_Q2_K_AMPERE 4
1563#else
1564#define MMQ_X_Q2_K_AMPERE 64
1565#define MMQ_Y_Q2_K_AMPERE 128
1566#define NWARPS_Q2_K_AMPERE 4
1567#endif
1568#define MMQ_X_Q2_K_PASCAL 64
1569#define MMQ_Y_Q2_K_PASCAL 64
1570#define NWARPS_Q2_K_PASCAL 8
1571
1572template <bool need_check> static void
1573mul_mat_q2_K(
1574 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1575 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1576 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
1577 sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
1578 sycl::half2 *tile_y_ds) {
1579 int * tile_x_ql = nullptr;
1580 sycl::half2 *tile_x_dm = nullptr;
1581 int * tile_x_qh = nullptr;
1582 int * tile_x_sc = nullptr;
1583
1584//sycl_todo: change according to hardware
1585 const int mmq_x = MMQ_X_Q2_K_AMPERE;
1586 const int mmq_y = MMQ_Y_Q2_K_AMPERE;
1587 const int nwarps = NWARPS_Q2_K_AMPERE;
1588 allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1589 tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
1590 mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
1591 load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
1592 vec_dot_q2_K_q8_1_mul_mat>(
1593 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1594 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1595}
1596
1597#define MMQ_X_Q3_K_RDNA2 128
1598#define MMQ_Y_Q3_K_RDNA2 64
1599#define NWARPS_Q3_K_RDNA2 8
1600#define MMQ_X_Q3_K_RDNA1 32
1601#define MMQ_Y_Q3_K_RDNA1 128
1602#define NWARPS_Q3_K_RDNA1 8
1603#if defined(SYCL_USE_XMX)
1604#define MMQ_X_Q3_K_AMPERE 4
1605#define MMQ_Y_Q3_K_AMPERE 32
1606#define NWARPS_Q3_K_AMPERE 4
1607#else
1608#define MMQ_X_Q3_K_AMPERE 128
1609#define MMQ_Y_Q3_K_AMPERE 128
1610#define NWARPS_Q3_K_AMPERE 4
1611#endif
1612#define MMQ_X_Q3_K_PASCAL 64
1613#define MMQ_Y_Q3_K_PASCAL 64
1614#define NWARPS_Q3_K_PASCAL 8
1615
1616template <bool need_check> static void
1617mul_mat_q3_K(
1618 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1619 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1620 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
1621 sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
1622 int *tile_y_qs, sycl::half2 *tile_y_ds) {
1623 int * tile_x_ql = nullptr;
1624 sycl::half2 *tile_x_dm = nullptr;
1625 int * tile_x_qh = nullptr;
1626 int * tile_x_sc = nullptr;
1627
1628//sycl_todo: change according to hardware
1629 const int mmq_x = MMQ_X_Q3_K_AMPERE;
1630 const int mmq_y = MMQ_Y_Q3_K_AMPERE;
1631 const int nwarps = NWARPS_Q3_K_AMPERE;
1632 allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1633 tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
1634 tile_x_sc_q3_K);
1635 mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
1636 load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
1637 vec_dot_q3_K_q8_1_mul_mat>(
1638 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1639 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1640}
1641
1642#define MMQ_X_Q4_K_RDNA2 64
1643#define MMQ_Y_Q4_K_RDNA2 128
1644#define NWARPS_Q4_K_RDNA2 8
1645#define MMQ_X_Q4_K_RDNA1 32
1646#define MMQ_Y_Q4_K_RDNA1 64
1647#define NWARPS_Q4_K_RDNA1 8
1648#if defined(SYCL_USE_XMX)
1649#define MMQ_X_Q4_K_AMPERE 4
1650#define MMQ_Y_Q4_K_AMPERE 32
1651#define NWARPS_Q4_K_AMPERE 4
1652#else
1653#define MMQ_X_Q4_K_AMPERE 64
1654#define MMQ_Y_Q4_K_AMPERE 128
1655#define NWARPS_Q4_K_AMPERE 4
1656#endif
1657#define MMQ_X_Q4_K_PASCAL 64
1658#define MMQ_Y_Q4_K_PASCAL 64
1659#define NWARPS_Q4_K_PASCAL 8
1660
1661template <bool need_check> static void
1662 mul_mat_q4_K(
1663 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1664 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1665 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
1666 sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
1667 sycl::half2 *tile_y_ds) {
1668 int * tile_x_ql = nullptr;
1669 sycl::half2 *tile_x_dm = nullptr;
1670 int * tile_x_qh = nullptr;
1671 int * tile_x_sc = nullptr;
1672
1673//sycl_todo: change according to hardware
1674 const int mmq_x = MMQ_X_Q4_K_AMPERE;
1675 const int mmq_y = MMQ_Y_Q4_K_AMPERE;
1676 const int nwarps = NWARPS_Q4_K_AMPERE;
1677 allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1678 tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
1679 mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
1680 load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
1681 vec_dot_q4_K_q8_1_mul_mat>(
1682 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1683 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1684}
1685
1686#define MMQ_X_Q5_K_RDNA2 64
1687#define MMQ_Y_Q5_K_RDNA2 128
1688#define NWARPS_Q5_K_RDNA2 8
1689#define MMQ_X_Q5_K_RDNA1 32
1690#define MMQ_Y_Q5_K_RDNA1 64
1691#define NWARPS_Q5_K_RDNA1 8
1692#if defined(SYCL_USE_XMX)
1693#define MMQ_X_Q5_K_AMPERE 4
1694#define MMQ_Y_Q5_K_AMPERE 32
1695#define NWARPS_Q5_K_AMPERE 4
1696#else
1697#define MMQ_X_Q5_K_AMPERE 64
1698#define MMQ_Y_Q5_K_AMPERE 128
1699#define NWARPS_Q5_K_AMPERE 4
1700#endif
1701#define MMQ_X_Q5_K_PASCAL 64
1702#define MMQ_Y_Q5_K_PASCAL 64
1703#define NWARPS_Q5_K_PASCAL 8
1704
1705template <bool need_check> static void
1706mul_mat_q5_K(
1707 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1708 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1709 const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
1710 sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
1711 sycl::half2 *tile_y_ds) {
1712 int * tile_x_ql = nullptr;
1713 sycl::half2 *tile_x_dm = nullptr;
1714 int * tile_x_qh = nullptr;
1715 int * tile_x_sc = nullptr;
1716
1717//sycl_todo: change according to hardware
1718 const int mmq_x = MMQ_X_Q5_K_AMPERE;
1719 const int mmq_y = MMQ_Y_Q5_K_AMPERE;
1720 const int nwarps = NWARPS_Q5_K_AMPERE;
1721 allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1722 tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
1723 mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
1724 load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
1725 vec_dot_q5_K_q8_1_mul_mat>(
1726 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1727 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1728}
1729
1730#define MMQ_X_Q6_K_RDNA2 64
1731#define MMQ_Y_Q6_K_RDNA2 128
1732#define NWARPS_Q6_K_RDNA2 8
1733#define MMQ_X_Q6_K_RDNA1 32
1734#define MMQ_Y_Q6_K_RDNA1 64
1735#define NWARPS_Q6_K_RDNA1 8
1736#if defined(SYCL_USE_XMX)
1737#define MMQ_X_Q6_K_AMPERE 4
1738#define MMQ_Y_Q6_K_AMPERE 32
1739#define NWARPS_Q6_K_AMPERE 4
1740#else
1741#define MMQ_X_Q6_K_AMPERE 64
1742#define MMQ_Y_Q6_K_AMPERE 64
1743#define NWARPS_Q6_K_AMPERE 4
1744#endif
1745#define MMQ_X_Q6_K_PASCAL 64
1746#define MMQ_Y_Q6_K_PASCAL 64
1747#define NWARPS_Q6_K_PASCAL 8
1748
1749template <bool need_check> static void
1750 mul_mat_q6_K(
1751 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1752 const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1753 const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
1754 int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1755 // int * tile_x_ql = nullptr;
1756 // sycl::half2 *tile_x_dm = nullptr;
1757 int * tile_x_qh = nullptr;
1758 // int * tile_x_sc = nullptr;
1759
1760//sycl_todo: change according to hardware
1761 const int mmq_x = MMQ_X_Q6_K_AMPERE;
1762 const int mmq_y = MMQ_Y_Q6_K_AMPERE;
1763 const int nwarps = NWARPS_Q6_K_AMPERE;
1764 allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1765 tile_x_ql, tile_x_dm, tile_x_sc);
1766 mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
1767 load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
1768 vec_dot_q6_K_q8_1_mul_mat>(
1769 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1770 tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1771}
1772
// Host-side launcher for the q4_0 x q8_1 MMQ kernel.
// Picks tile sizes (mmq_x/mmq_y) and warp count by device generation, sizes
// the local-memory tiles accordingly, and submits either the unchecked kernel
// (nrows_x divisible by mmq_y) or the bounds-checked variant.
static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Select the tile configuration for this device generation.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q4_0_RDNA2;
        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
        nwarps = NWARPS_Q4_0_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q4_0_RDNA1;
        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
        nwarps = NWARPS_Q4_0_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q4_0_AMPERE;
        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
        nwarps = NWARPS_Q4_0_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q4_0_PASCAL;
        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
        nwarps = NWARPS_Q4_0_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per mmq_y x mmq_x tile of dst.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no row clamping needed in the kernel.
        const bool need_check = false;
        /*
        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                // Local-memory tiles for the x quants/scales and y quants/scales.
                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_qs_q4_0_acc_ct1),
                            get_pointer(tile_x_d_q4_0_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial last row-tile: the kernel must clamp row indices on loads.
        const bool need_check = true;
        /*
        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_qs_q4_0_acc_ct1),
                            get_pointer(tile_x_d_q4_0_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
1887
1888static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
1889 float *dst, const int ncols_x,
1890 const int nrows_x, const int ncols_y,
1891 const int nrows_y, const int nrows_dst,
1892 dpct::queue_ptr stream) try {
1893
1894 int id;
1895 SYCL_CHECK(
1896 CHECK_TRY_ERROR(id = get_current_device_id()));
1897 const int compute_capability = ggml_sycl_info().devices[id].cc;
1898
1899 int mmq_x, mmq_y, nwarps;
1900 if (compute_capability >= VER_GEN13) {
1901 mmq_x = MMQ_X_Q4_1_RDNA2;
1902 mmq_y = MMQ_Y_Q4_1_RDNA2;
1903 nwarps = NWARPS_Q4_1_RDNA2;
1904 } else if (compute_capability >= VER_GEN12) {
1905 mmq_x = MMQ_X_Q4_1_RDNA1;
1906 mmq_y = MMQ_Y_Q4_1_RDNA1;
1907 nwarps = NWARPS_Q4_1_RDNA1;
1908 } else if (compute_capability >= VER_GEN9) {
1909 mmq_x = MMQ_X_Q4_1_AMPERE;
1910 mmq_y = MMQ_Y_Q4_1_AMPERE;
1911 nwarps = NWARPS_Q4_1_AMPERE;
1912 } else if (compute_capability >= VER_4VEC) {
1913 mmq_x = MMQ_X_Q4_1_PASCAL;
1914 mmq_y = MMQ_Y_Q4_1_PASCAL;
1915 nwarps = NWARPS_Q4_1_PASCAL;
1916 } else {
1917 GGML_ABORT("fatal error");
1918 }
1919
1920 const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
1921 const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
1922 const sycl::range<3> block_nums(1, block_num_y, block_num_x);
1923 const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
1924
1925 if (nrows_x % mmq_y == 0) {
1926 const bool need_check = false;
1927 /*
1928 DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
1929 the limit. To get the device limit, query
1930 info::device::max_work_group_size. Adjust the work-group size if needed.
1931 */
1932 {
1933 dpct::has_capability_or_fail(stream->get_device(),
1934 {sycl::aspect::fp16});
1935
1936 stream->submit([&](sycl::handler &cgh) {
1937 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1938 sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1939 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1940 sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1941 cgh);
1942 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1943 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1944 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1945 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1946
1947 cgh.parallel_for(
1948 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1949 [=](sycl::nd_item<3> item_ct1) {
1950 mul_mat_q4_1<need_check>(
1951 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1952 nrows_dst, item_ct1,
1953 get_pointer(tile_x_qs_q4_1_acc_ct1),
1954 get_pointer(tile_x_dm_q4_1_acc_ct1),
1955 get_pointer(tile_y_qs_acc_ct1),
1956 get_pointer(tile_y_ds_acc_ct1));
1957 });
1958 });
1959 }
1960 } else {
1961 const bool need_check = true;
1962 /*
1963 DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
1964 the limit. To get the device limit, query
1965 info::device::max_work_group_size. Adjust the work-group size if needed.
1966 */
1967 {
1968 dpct::has_capability_or_fail(stream->get_device(),
1969 {sycl::aspect::fp16});
1970
1971 stream->submit([&](sycl::handler &cgh) {
1972 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1973 sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1974 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1975 sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1976 cgh);
1977 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1978 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1979 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1980 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1981
1982 cgh.parallel_for(
1983 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1984 [=](sycl::nd_item<3> item_ct1) {
1985 mul_mat_q4_1<need_check>(
1986 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1987 nrows_dst, item_ct1,
1988 get_pointer(tile_x_qs_q4_1_acc_ct1),
1989 get_pointer(tile_x_dm_q4_1_acc_ct1),
1990 get_pointer(tile_y_qs_acc_ct1),
1991 get_pointer(tile_y_ds_acc_ct1));
1992 });
1993 });
1994 }
1995 }
1996}
1997catch (sycl::exception const &exc) {
1998 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1999 << ", line:" << __LINE__ << std::endl;
2000 std::exit(1);
2001}
2002
// Dispatch helper for the q5_0 x q8_1 mul-mat-q kernel: selects the tile
// shape (mmq_x x mmq_y) and warps-per-work-group for the device generation,
// sizes the work-group local tiles, and launches mul_mat_q5_0.
// Note the q-tile is 2*WARP_SIZE wide here: q5_0 unpacks into twice the
// storage of a 4-bit tile (low nibbles plus high bits).
static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Per-generation tile configuration, newest generation first.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x = MMQ_X_Q5_0_RDNA2;
        mmq_y = MMQ_Y_Q5_0_RDNA2;
        nwarps = NWARPS_Q5_0_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x = MMQ_X_Q5_0_RDNA1;
        mmq_y = MMQ_Y_Q5_0_RDNA1;
        nwarps = NWARPS_Q5_0_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x = MMQ_X_Q5_0_AMPERE;
        mmq_y = MMQ_Y_Q5_0_AMPERE;
        nwarps = NWARPS_Q5_0_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x = MMQ_X_Q5_0_PASCAL;
        mmq_y = MMQ_Y_Q5_0_PASCAL;
        nwarps = NWARPS_Q5_0_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per (row-tile, column-tile) pair; ceil-divide by tile size.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no bounds checking needed.
        const bool need_check = false;
        /*
        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q5_0_acc_ct1),
                            get_pointer(tile_x_d_q5_0_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial edge tiles exist: instantiate the bounds-checked kernel.
        const bool need_check = true;
        /*
        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q5_0_acc_ct1),
                            get_pointer(tile_x_d_q5_0_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
2117
// Dispatch helper for the q5_1 x q8_1 mul-mat-q kernel: selects the tile
// shape (mmq_x x mmq_y) and warps-per-work-group for the device generation,
// sizes the work-group local tiles, and launches mul_mat_q5_1.
// Unlike q5_0 (float scale), q5_1 keeps a half2 (scale, min) per block, so
// the x scale tile is a sycl::half2 accessor.
static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Per-generation tile configuration, newest generation first.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x = MMQ_X_Q5_1_RDNA2;
        mmq_y = MMQ_Y_Q5_1_RDNA2;
        nwarps = NWARPS_Q5_1_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x = MMQ_X_Q5_1_RDNA1;
        mmq_y = MMQ_Y_Q5_1_RDNA1;
        nwarps = NWARPS_Q5_1_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x = MMQ_X_Q5_1_AMPERE;
        mmq_y = MMQ_Y_Q5_1_AMPERE;
        nwarps = NWARPS_Q5_1_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x = MMQ_X_Q5_1_PASCAL;
        mmq_y = MMQ_Y_Q5_1_PASCAL;
        nwarps = NWARPS_Q5_1_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per (row-tile, column-tile) pair; ceil-divide by tile size.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no bounds checking needed.
        const bool need_check = false;
        /*
        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q5_1_acc_ct1),
                            get_pointer(tile_x_dm_q5_1_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial edge tiles exist: instantiate the bounds-checked kernel.
        const bool need_check = true;
        /*
        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q5_1_acc_ct1),
                            get_pointer(tile_x_dm_q5_1_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
2232
2233static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
2234 float *dst, const int ncols_x,
2235 const int nrows_x, const int ncols_y,
2236 const int nrows_y, const int nrows_dst,
2237 dpct::queue_ptr stream) try {
2238
2239 int id;
2240 SYCL_CHECK(
2241 CHECK_TRY_ERROR(id = get_current_device_id()));
2242 const int compute_capability = ggml_sycl_info().devices[id].cc;
2243
2244 int mmq_x, mmq_y, nwarps;
2245 if (compute_capability >= VER_GEN13) {
2246 mmq_x = MMQ_X_Q8_0_RDNA2;
2247 mmq_y = MMQ_Y_Q8_0_RDNA2;
2248 nwarps = NWARPS_Q8_0_RDNA2;
2249 } else if (compute_capability >= VER_GEN12) {
2250 mmq_x = MMQ_X_Q8_0_RDNA1;
2251 mmq_y = MMQ_Y_Q8_0_RDNA1;
2252 nwarps = NWARPS_Q8_0_RDNA1;
2253 } else if (compute_capability >= VER_GEN9) {
2254 mmq_x = MMQ_X_Q8_0_AMPERE;
2255 mmq_y = MMQ_Y_Q8_0_AMPERE;
2256 nwarps = NWARPS_Q8_0_AMPERE;
2257 } else if (compute_capability >= VER_4VEC) {
2258 mmq_x = MMQ_X_Q8_0_PASCAL;
2259 mmq_y = MMQ_Y_Q8_0_PASCAL;
2260 nwarps = NWARPS_Q8_0_PASCAL;
2261 } else {
2262 GGML_ABORT("fatal error");
2263 }
2264
2265 const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2266 const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2267 const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2268 const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2269
2270 if (nrows_x % mmq_y == 0) {
2271 const bool need_check = false;
2272 /*
2273 DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
2274 the limit. To get the device limit, query
2275 info::device::max_work_group_size. Adjust the work-group size if needed.
2276 */
2277 {
2278 dpct::has_capability_or_fail(stream->get_device(),
2279 {sycl::aspect::fp16});
2280
2281 stream->submit([&](sycl::handler &cgh) {
2282 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2283 sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2284 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2285 sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2286 cgh);
2287 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2288 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2289 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2290 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2291
2292 cgh.parallel_for(
2293 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2294 [=](sycl::nd_item<3> item_ct1) {
2295 mul_mat_q8_0<need_check>(
2296 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2297 nrows_dst, item_ct1,
2298 get_pointer(tile_x_qs_q8_0_acc_ct1),
2299 get_pointer(tile_x_d_q8_0_acc_ct1),
2300 get_pointer(tile_y_qs_acc_ct1),
2301 get_pointer(tile_y_ds_acc_ct1));
2302 });
2303 });
2304 }
2305 } else {
2306 const bool need_check = true;
2307 /*
2308 DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
2309 the limit. To get the device limit, query
2310 info::device::max_work_group_size. Adjust the work-group size if needed.
2311 */
2312 {
2313 dpct::has_capability_or_fail(stream->get_device(),
2314 {sycl::aspect::fp16});
2315
2316 stream->submit([&](sycl::handler &cgh) {
2317 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2318 sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2319 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2320 sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2321 cgh);
2322 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2323 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2324 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2325 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2326
2327 cgh.parallel_for(
2328 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2329 [=](sycl::nd_item<3> item_ct1) {
2330 mul_mat_q8_0<need_check>(
2331 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2332 nrows_dst, item_ct1,
2333 get_pointer(tile_x_qs_q8_0_acc_ct1),
2334 get_pointer(tile_x_d_q8_0_acc_ct1),
2335 get_pointer(tile_y_qs_acc_ct1),
2336 get_pointer(tile_y_ds_acc_ct1));
2337 });
2338 });
2339 }
2340 }
2341}
2342catch (sycl::exception const &exc) {
2343 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2344 << ", line:" << __LINE__ << std::endl;
2345 std::exit(1);
2346}
2347
// Dispatch helper for the q2_K x q8_1 mul-mat-q kernel: selects the tile
// shape (mmq_x x mmq_y) and warps-per-work-group for the device generation,
// sizes the work-group local tiles (quants, half2 scale/min pairs, and a
// per-8-quants scales tile), and launches mul_mat_q2_K.
static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Per-generation tile configuration, newest generation first.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x = MMQ_X_Q2_K_RDNA2;
        mmq_y = MMQ_Y_Q2_K_RDNA2;
        nwarps = NWARPS_Q2_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x = MMQ_X_Q2_K_RDNA1;
        mmq_y = MMQ_Y_Q2_K_RDNA1;
        nwarps = NWARPS_Q2_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x = MMQ_X_Q2_K_AMPERE;
        mmq_y = MMQ_Y_Q2_K_AMPERE;
        nwarps = NWARPS_Q2_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x = MMQ_X_Q2_K_PASCAL;
        mmq_y = MMQ_Y_Q2_K_PASCAL;
        nwarps = NWARPS_Q2_K_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per (row-tile, column-tile) pair; ceil-divide by tile size.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no bounds checking needed.
        const bool need_check = false;
        /*
        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q2_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q2_K_acc_ct1),
                            get_pointer(tile_x_dm_q2_K_acc_ct1),
                            get_pointer(tile_x_sc_q2_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial edge tiles exist: instantiate the bounds-checked kernel.
        const bool need_check = true;
        /*
        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q2_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q2_K_acc_ct1),
                            get_pointer(tile_x_dm_q2_K_acc_ct1),
                            get_pointer(tile_x_sc_q2_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
2468
// Dispatch helper for the q3_K x q8_1 mul-mat-q kernel: selects the tile
// shape (mmq_x x mmq_y) and warps-per-work-group for the device generation,
// sizes the work-group local tiles (including the extra high-bit tile q3_K
// needs), and launches mul_mat_q3_K.
// The whole body is guarded by QK_K == 256; for other QK_K values this
// function compiles to a no-op.
static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

#if QK_K == 256

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Per-generation tile configuration, newest generation first.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x = MMQ_X_Q3_K_RDNA2;
        mmq_y = MMQ_Y_Q3_K_RDNA2;
        nwarps = NWARPS_Q3_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x = MMQ_X_Q3_K_RDNA1;
        mmq_y = MMQ_Y_Q3_K_RDNA1;
        nwarps = NWARPS_Q3_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x = MMQ_X_Q3_K_AMPERE;
        mmq_y = MMQ_Y_Q3_K_AMPERE;
        nwarps = NWARPS_Q3_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x = MMQ_X_Q3_K_PASCAL;
        mmq_y = MMQ_Y_Q3_K_PASCAL;
        nwarps = NWARPS_Q3_K_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per (row-tile, column-tile) pair; ceil-divide by tile size.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no bounds checking needed.
        const bool need_check = false;
        /*
        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q3_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q3_K_acc_ct1),
                            get_pointer(tile_x_dm_q3_K_acc_ct1),
                            get_pointer(tile_x_qh_q3_K_acc_ct1),
                            get_pointer(tile_x_sc_q3_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial edge tiles exist: instantiate the bounds-checked kernel.
        const bool need_check = true;
        /*
        DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q3_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q3_K_acc_ct1),
                            get_pointer(tile_x_dm_q3_K_acc_ct1),
                            get_pointer(tile_x_qh_q3_K_acc_ct1),
                            get_pointer(tile_x_sc_q3_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
#endif
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
2598
// Dispatch helper for the q4_K x q8_1 mul-mat-q kernel: selects the tile
// shape (mmq_x x mmq_y) and warps-per-work-group for the device generation,
// sizes the work-group local tiles (quants, half2 super-block scale/min
// pairs, and a per-8-quants scales tile), and launches mul_mat_q4_K.
static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    // Per-generation tile configuration, newest generation first.
    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x = MMQ_X_Q4_K_RDNA2;
        mmq_y = MMQ_Y_Q4_K_RDNA2;
        nwarps = NWARPS_Q4_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x = MMQ_X_Q4_K_RDNA1;
        mmq_y = MMQ_Y_Q4_K_RDNA1;
        nwarps = NWARPS_Q4_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x = MMQ_X_Q4_K_AMPERE;
        mmq_y = MMQ_Y_Q4_K_AMPERE;
        nwarps = NWARPS_Q4_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x = MMQ_X_Q4_K_PASCAL;
        mmq_y = MMQ_Y_Q4_K_PASCAL;
        nwarps = NWARPS_Q4_K_PASCAL;
    } else {
        GGML_ABORT("fatal error");
    }

    // One work-group per (row-tile, column-tile) pair; ceil-divide by tile size.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        // Rows divide evenly into tiles: no bounds checking needed.
        const bool need_check = false;
        /*
        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q4_K_acc_ct1),
                            get_pointer(tile_x_dm_q4_K_acc_ct1),
                            get_pointer(tile_x_sc_q4_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    } else {
        // Partial edge tiles exist: instantiate the bounds-checked kernel.
        const bool need_check = true;
        /*
        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            get_pointer(tile_x_ql_q4_K_acc_ct1),
                            get_pointer(tile_x_dm_q4_K_acc_ct1),
                            get_pointer(tile_x_sc_q4_K_acc_ct1),
                            get_pointer(tile_y_qs_acc_ct1),
                            get_pointer(tile_y_ds_acc_ct1));
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
2719
2720static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
2721 float *dst, const int ncols_x,
2722 const int nrows_x, const int ncols_y,
2723 const int nrows_y, const int nrows_dst,
2724 dpct::queue_ptr stream) try {
2725
2726 int id;
2727 SYCL_CHECK(
2728 CHECK_TRY_ERROR(id = get_current_device_id()));
2729 const int compute_capability = ggml_sycl_info().devices[id].cc;
2730
2731 int mmq_x, mmq_y, nwarps;
2732 if (compute_capability >= VER_GEN13) {
2733 mmq_x = MMQ_X_Q5_K_RDNA2;
2734 mmq_y = MMQ_Y_Q5_K_RDNA2;
2735 nwarps = NWARPS_Q5_K_RDNA2;
2736 } else if (compute_capability >= VER_GEN12) {
2737 mmq_x = MMQ_X_Q5_K_RDNA1;
2738 mmq_y = MMQ_Y_Q5_K_RDNA1;
2739 nwarps = NWARPS_Q5_K_RDNA1;
2740 } else if (compute_capability >= VER_GEN9) {
2741 mmq_x = MMQ_X_Q5_K_AMPERE;
2742 mmq_y = MMQ_Y_Q5_K_AMPERE;
2743 nwarps = NWARPS_Q5_K_AMPERE;
2744 } else if (compute_capability >= VER_4VEC) {
2745 mmq_x = MMQ_X_Q5_K_PASCAL;
2746 mmq_y = MMQ_Y_Q5_K_PASCAL;
2747 nwarps = NWARPS_Q5_K_PASCAL;
2748 } else {
2749 GGML_ABORT("fatal error");
2750 }
2751
2752 const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2753 const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2754 const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2755 const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2756
2757 if (nrows_x % mmq_y == 0) {
2758 const bool need_check = false;
2759 /*
2760 DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
2761 the limit. To get the device limit, query
2762 info::device::max_work_group_size. Adjust the work-group size if needed.
2763 */
2764 {
2765 dpct::has_capability_or_fail(stream->get_device(),
2766 {sycl::aspect::fp16});
2767
2768 stream->submit([&](sycl::handler &cgh) {
2769 sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2770 sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2771 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2772 sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2773 cgh);
2774 sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2775 sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2776 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2777 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2778 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2779 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2780
2781 cgh.parallel_for(
2782 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2783 [=](sycl::nd_item<3> item_ct1) {
2784 mul_mat_q5_K<need_check>(
2785 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2786 nrows_dst, item_ct1,
2787 get_pointer(tile_x_ql_q5_K_acc_ct1),
2788 get_pointer(tile_x_dm_q5_K_acc_ct1),
2789 get_pointer(tile_x_sc_q5_K_acc_ct1),
2790 get_pointer(tile_y_qs_acc_ct1),
2791 get_pointer(tile_y_ds_acc_ct1));
2792 });
2793 });
2794 }
2795 } else {
2796 const bool need_check = true;
2797 /*
2798 DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
2799 the limit. To get the device limit, query
2800 info::device::max_work_group_size. Adjust the work-group size if needed.
2801 */
2802 {
2803 dpct::has_capability_or_fail(stream->get_device(),
2804 {sycl::aspect::fp16});
2805
2806 stream->submit([&](sycl::handler &cgh) {
2807 sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2808 sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2809 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2810 sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2811 cgh);
2812 sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2813 sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2814 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2815 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2816 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2817 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2818
2819 cgh.parallel_for(
2820 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2821 [=](sycl::nd_item<3> item_ct1) {
2822 mul_mat_q5_K<need_check>(
2823 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2824 nrows_dst, item_ct1,
2825 get_pointer(tile_x_ql_q5_K_acc_ct1),
2826 get_pointer(tile_x_dm_q5_K_acc_ct1),
2827 get_pointer(tile_x_sc_q5_K_acc_ct1),
2828 get_pointer(tile_y_qs_acc_ct1),
2829 get_pointer(tile_y_ds_acc_ct1));
2830 });
2831 });
2832 }
2833 }
2834}
2835catch (sycl::exception const &exc) {
2836 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2837 << ", line:" << __LINE__ << std::endl;
2838 std::exit(1);
2839}
2840
2841static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
2842 float *dst, const int ncols_x,
2843 const int nrows_x, const int ncols_y,
2844 const int nrows_y, const int nrows_dst,
2845 dpct::queue_ptr stream) try {
2846
2847 int id;
2848 SYCL_CHECK(
2849 CHECK_TRY_ERROR(id = get_current_device_id()));
2850 const int compute_capability = ggml_sycl_info().devices[id].cc;
2851
2852 int mmq_x, mmq_y, nwarps;
2853 if (compute_capability >= VER_GEN13) {
2854 mmq_x = MMQ_X_Q6_K_RDNA2;
2855 mmq_y = MMQ_Y_Q6_K_RDNA2;
2856 nwarps = NWARPS_Q6_K_RDNA2;
2857 } else if (compute_capability >= VER_GEN12) {
2858 mmq_x = MMQ_X_Q6_K_RDNA1;
2859 mmq_y = MMQ_Y_Q6_K_RDNA1;
2860 nwarps = NWARPS_Q6_K_RDNA1;
2861 } else if (compute_capability >= VER_GEN9) {
2862 mmq_x = MMQ_X_Q6_K_AMPERE;
2863 mmq_y = MMQ_Y_Q6_K_AMPERE;
2864 nwarps = NWARPS_Q6_K_AMPERE;
2865 } else if (compute_capability >= VER_4VEC) {
2866 mmq_x = MMQ_X_Q6_K_PASCAL;
2867 mmq_y = MMQ_Y_Q6_K_PASCAL;
2868 nwarps = NWARPS_Q6_K_PASCAL;
2869 } else {
2870 GGML_ABORT("fatal error");
2871 }
2872
2873 const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2874 const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2875 const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2876 const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2877
2878 if (nrows_x % mmq_y == 0) {
2879 const bool need_check = false;
2880 /*
2881 DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
2882 the limit. To get the device limit, query
2883 info::device::max_work_group_size. Adjust the work-group size if needed.
2884 */
2885 {
2886 dpct::has_capability_or_fail(stream->get_device(),
2887 {sycl::aspect::fp16});
2888
2889 stream->submit([&](sycl::handler &cgh) {
2890 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2891 sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2892 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2893 sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2894 cgh);
2895 sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2896 sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2897 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2898 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2899 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2900 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2901
2902 cgh.parallel_for(
2903 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2904 [=](sycl::nd_item<3> item_ct1) {
2905 mul_mat_q6_K<need_check>(
2906 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2907 nrows_dst, item_ct1,
2908 get_pointer(tile_x_ql_acc_ct1),
2909 get_pointer(tile_x_dm_acc_ct1),
2910 get_pointer(tile_x_sc_acc_ct1),
2911 get_pointer(tile_y_qs_acc_ct1),
2912 get_pointer(tile_y_ds_acc_ct1));
2913 });
2914 });
2915 }
2916 } else {
2917 const bool need_check = true;
2918 /*
2919 DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
2920 the limit. To get the device limit, query
2921 info::device::max_work_group_size. Adjust the work-group size if needed.
2922 */
2923 {
2924 dpct::has_capability_or_fail(stream->get_device(),
2925 {sycl::aspect::fp16});
2926
2927 stream->submit([&](sycl::handler &cgh) {
2928 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2929 sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2930 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2931 sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2932 cgh);
2933 sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2934 sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2935 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2936 sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2937 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2938 sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2939
2940 cgh.parallel_for(
2941 sycl::nd_range<3>(block_nums * block_dims, block_dims),
2942 [=](sycl::nd_item<3> item_ct1) {
2943 mul_mat_q6_K<need_check>(
2944 vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2945 nrows_dst, item_ct1,
2946 get_pointer(tile_x_ql_acc_ct1),
2947 get_pointer(tile_x_dm_acc_ct1),
2948 get_pointer(tile_x_sc_acc_ct1),
2949 get_pointer(tile_y_qs_acc_ct1),
2950 get_pointer(tile_y_ds_acc_ct1));
2951 });
2952 });
2953 }
2954 }
2955}
2956catch (sycl::exception const &exc) {
2957 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2958 << ", line:" << __LINE__ << std::endl;
2959 std::exit(1);
2960}
2961
2962void ggml_sycl_op_mul_mat_q(
2963 ggml_backend_sycl_context & ctx,
2964 const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
2965 const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
2966 float *dst_dd_i, const int64_t row_low, const int64_t row_high,
2967 const int64_t src1_ncols, const int64_t src1_padded_row_size,
2968 const dpct::queue_ptr &stream) try {
2969
2970 const int64_t ne00 = src0->ne[0];
2971
2972 const int64_t ne10 = src1->ne[0];
2973 GGML_ASSERT(ne10 % QK8_1 == 0);
2974
2975 const int64_t ne0 = dst->ne[0];
2976
2977 const int64_t row_diff = row_high - row_low;
2978
2979 int device_id;
2980 SYCL_CHECK(
2981 CHECK_TRY_ERROR(device_id = get_current_device_id()));
2982
2983 // the main device has a larger memory buffer to hold the results from all GPUs
2984 // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
2985 const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
2986
2987 switch (src0->type) {
2988 case GGML_TYPE_Q4_0:
2989 ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2990 break;
2991 case GGML_TYPE_Q4_1:
2992 ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2993 break;
2994 case GGML_TYPE_Q5_0:
2995 ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2996 break;
2997 case GGML_TYPE_Q5_1:
2998 ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2999 break;
3000 case GGML_TYPE_Q8_0:
3001 ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3002 break;
3003 case GGML_TYPE_Q2_K:
3004 ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3005 break;
3006 case GGML_TYPE_Q3_K:
3007 ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3008 break;
3009 case GGML_TYPE_Q4_K:
3010 ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3011 break;
3012 case GGML_TYPE_Q5_K:
3013 ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3014 break;
3015 case GGML_TYPE_Q6_K:
3016 ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3017 break;
3018 default:
3019 GGML_ABORT("fatal error");
3020 }
3021
3022 GGML_UNUSED(src1);
3023 GGML_UNUSED(dst);
3024 GGML_UNUSED(src1_ddf_i);
3025}
3026catch (sycl::exception const &exc) {
3027 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3028 << ", line:" << __LINE__ << std::endl;
3029 std::exit(1);
3030}