1#include "convert.hpp"
2#include "dequantize.hpp"
3#include "presets.hpp"
4
5#if defined(__INTEL_LLVM_COMPILER)
6 #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
7 #include <sycl/ext/oneapi/bfloat16.hpp>
8 #define GGML_SYCL_HAS_BF16
9 #endif
10#endif
11
// Generic dequantization kernel: each work-item decodes 2 quantized values
// into 2 dst_t outputs. qk = elements per quant block, qr = quantization
// ratio, dequantize_kernel = per-type decode callback, k = total elements.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                             const sycl::nd_item<3> &item_ct1) {
    // each work-item handles 2 consecutive output elements, hence the factor 2
    const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                           item_ct1.get_local_id(2));

    if (i >= k) {
        return;
    }

    const int64_t ib = i/qk; // block index
    const int64_t iqs = (i%qk)/qr; // quant index
    const int64_t iybs = i - i%qk; // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;

    // dequantize: v receives the pair of decoded values for this work-item
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    // the two decoded values are qk/2 apart in y (adjacent when qr == 1)
    y[iybs + iqs + 0] = v.x();
    y[iybs + iqs + y_offset] = v.y();
}
34
35template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
36static void dequantize_block_sycl(const void *__restrict__ vx,
37 dst_t *__restrict__ y, const int64_t k,
38 dpct::queue_ptr stream) {
39 const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
40 {
41 dpct::has_capability_or_fail(stream->get_device(),
42 {sycl::aspect::fp16});
43 stream->parallel_for(
44 sycl::nd_range<3>(
45 sycl::range<3>(1, 1, num_blocks) *
46 sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
47 sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
48 [=](sycl::nd_item<3> item_ct1) {
49 dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
50 });
51 }
52}
53
// Launch Q2_K dequantization: one work-group per QK_K super-block.
template <typename dst_t>
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // 64 work-items per 256-element super-block
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q2_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // smaller super-block variant: 32 work-items suffice
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q2_K(vx, y, item_ct1);
                             });
    }

#endif
}
85
// Launch Q3_K dequantization: one work-group per QK_K super-block.
template <typename dst_t>
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // 64 work-items per 256-element super-block
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q3_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // smaller super-block variant: 32 work-items suffice
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q3_K(vx, y, item_ct1);
                             });
    }
#endif
}
116
117template <typename dst_t>
118static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
119 dpct::queue_ptr stream) {
120 const int64_t nb32 = k / 32;
121 const int64_t nb = (k + 255) / 256;
122 {
123 dpct::has_capability_or_fail(stream->get_device(),
124 {sycl::aspect::fp16});
125
126 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
127 sycl::range<3>(1, 1, 32),
128 sycl::range<3>(1, 1, 32)),
129 [=](sycl::nd_item<3> item_ct1) {
130 dequantize_block_q4_0(vx, y, nb32, item_ct1);
131 });
132 }
133}
134
// Launch Q4_0 dequantization for the reorder-optimized layout.
// One sub-group of WARP_SIZE work-items covers WARP_SIZE * QK4_0 elements;
// the kernel requires the exact sub-group size (reqd_sub_group_size).
template <typename dst_t>
static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {

    dpct::has_capability_or_fail(stream->get_device(),
                                 {sycl::aspect::fp16});

    int constexpr WARP_K = WARP_SIZE * QK4_0;  // elements handled per warp
    const int n_warp = (k + WARP_K - 1) / WARP_K;
    GGML_ASSERT(k % 2 == 0);  // kernel processes elements in pairs
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
                                sycl::range<3>(1, 1, WARP_SIZE),
                                sycl::range<3>(1, 1, WARP_SIZE)),
                         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
                             dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
                         });

}
153
154template <typename dst_t>
155static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
156 dpct::queue_ptr stream) {
157 const int64_t nb32 = k / 32;
158 const int64_t nb = (k + 255) / 256;
159 {
160 dpct::has_capability_or_fail(stream->get_device(),
161 {sycl::aspect::fp16});
162
163 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
164 sycl::range<3>(1, 1, 32),
165 sycl::range<3>(1, 1, 32)),
166 [=](sycl::nd_item<3> item_ct1) {
167 dequantize_block_q4_1(vx, y, nb32, item_ct1);
168 });
169 }
170}
171
172
// Launch Q4_K dequantization: one 32-item work-group per super-block.
// Uses 12 bytes of work-group local memory to stage the packed scales/mins.
template <typename dst_t>
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            // 12-byte scratch buffer for the super-block scale data
            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
                             });
        });
    }
}
192
193template <typename dst_t>
194static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
195 const int64_t nb = k / QK_K;
196 const size_t local_size = 32;
197 const size_t global_size = nb * local_size;
198
199 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
200
201 stream->submit([&](sycl::handler & cgh) {
202 sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
203
204 cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
205 [=](sycl::nd_item<1> item_ct1) {
206 dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
207 });
208 });
209}
210
// Launch Q5_K dequantization: one work-group per QK_K super-block.
template <typename dst_t>
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // 64 work-items per 256-element super-block
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q5_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // smaller super-block variant: 32 work-items suffice
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q5_K(vx, y, item_ct1);
                             });
    }

#endif
}
242
// Launch Q6_K dequantization: one work-group per QK_K super-block.
template <typename dst_t>
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                     dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
#if QK_K == 256
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // 64 work-items per 256-element super-block
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 64),
                                               sycl::range<3>(1, 1, 64)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q6_K(vx, y, item_ct1);
                             });
    }
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        // smaller super-block variant: 32 work-items suffice
        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_q6_K(vx, y, item_ct1);
                             });
    }

#endif
}
274
275template <typename dst_t>
276static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
277 const int64_t nb = k / QK_K;
278
279 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
280
281 stream->parallel_for(
282 sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
283 [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
284}
285
// Launch IQ1_S dequantization: one 32-item work-group per super-block.
// Passes the IQ1 codebook grid (iq1s_grid_gpu) to the kernel.
template <typename dst_t>
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq1_s(
                                     vx, y, item_ct1, iq1s_grid_gpu
                                     );
                             });
        });
    }
}
306
// Launch IQ1_M dequantization: one 32-item work-group per super-block.
// NOTE: intentionally reuses the IQ1_S codebook grid (iq1s_grid_gpu);
// IQ1_M shares the same grid table as IQ1_S.
template <typename dst_t>
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq1_m(
                                     vx, y, item_ct1, iq1s_grid_gpu
                                     );
                             });
        });
    }
}
327
// Launch IQ2_XXS dequantization: one 32-item work-group per super-block.
// Passes the IQ2_XXS codebook grid plus the shared sign/mask tables.
template <typename dst_t>
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq2_xxs(
                                     vx, y, item_ct1, iq2xxs_grid,
                                     ksigns_iq2xs, kmask_iq2xs);
                             });
        });
    }
}
348
// Launch IQ2_XS dequantization: one 32-item work-group per super-block.
// Passes the IQ2_XS codebook grid plus the shared sign/mask tables.
template <typename dst_t>
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq2_xs(
                                     vx, y, item_ct1, iq2xs_grid,
                                     ksigns_iq2xs, kmask_iq2xs);
                             });
        });
    }
}
369
370template <typename dst_t>
371static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
372 dpct::queue_ptr stream) {
373 const int64_t nb = k / QK_K;
374 {
375 dpct::has_capability_or_fail(stream->get_device(),
376 {sycl::aspect::fp16});
377
378 stream->submit([&](sycl::handler &cgh) {
379 cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
380 sycl::range<3>(1, 1, 32),
381 sycl::range<3>(1, 1, 32)),
382 [=](sycl::nd_item<3> item_ct1) {
383 dequantize_block_iq2_s(vx, y, item_ct1);
384 });
385 });
386 }
387}
388
389
// Launch IQ3_XXS dequantization: one 32-item work-group per super-block.
// Passes the IQ3_XXS codebook grid plus the shared sign/mask tables.
template <typename dst_t>
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq3_xxs(
                                     vx, y, item_ct1, iq3xxs_grid,
                                     ksigns_iq2xs, kmask_iq2xs);
                             });
        });
    }
}
410
// Launch IQ3_S dequantization: one 32-item work-group per super-block.
// Passes the shared mask table and the IQ3_S codebook grid.
template <typename dst_t>
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
    const int64_t nb = k / QK_K;  // number of super-blocks
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                   sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 dequantize_block_iq3_s(
                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
                             });
        });
    }
}
430
// Launch IQ4_XS dequantization: one 32-item work-group per super-block.
// For the QK_K == 64 build variant IQ4_XS degenerates to IQ4_NL and the
// call is forwarded accordingly.
template <typename dst_t>
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
    const int64_t nb = (k + QK_K - 1) / QK_K;  // rounded-up super-block count
#if QK_K == 64
    dequantize_row_iq4_nl_sycl(vx, y, k, stream);
#else
    {
        dpct::has_capability_or_fail(stream->get_device(),
                                     {sycl::aspect::fp16});

        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                      sycl::range<3>(1, 1, 32),
                                  sycl::range<3>(1, 1, 32)),
                [=](sycl::nd_item<3> item_ct1) {
                    dequantize_block_iq4_xs(vx, y, item_ct1);
                });
        });
    }
#endif
}
454
455template <typename dst_t>
456static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
457 dpct::queue_ptr stream) {
458 const int64_t nb = (k + QK_K - 1) / QK_K;
459 {
460 dpct::has_capability_or_fail(stream->get_device(),
461 {sycl::aspect::fp16});
462
463 stream->submit([&](sycl::handler &cgh) {
464 cgh.parallel_for(
465 sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
466 sycl::range<3>(1, 1, 32),
467 sycl::range<3>(1, 1, 32)),
468 [=](sycl::nd_item<3> item_ct1) {
469 dequantize_block_iq4_nl(vx, y, item_ct1);
470 });
471 });
472 }
473}
474
475template <typename dst_t>
476static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
477 const int nb = (k + QK_K - 1) / QK_K;
478 stream->parallel_for(
479 sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
480 [=](sycl::nd_item<3> item_ct1) {
481 dequantize_block_mxfp4(vx, y, item_ct1);
482 });
483}
484
// Element-wise cast src_t -> dst_t for a (possibly) non-contiguous source.
// s01/s02/s03 are source strides in elements for dims 1..3; the destination
// is written contiguously. The innermost dimension is covered with a
// grid-stride loop so a single launch can handle ne00 larger than the
// maximum SYCL global range.
template <typename src_t, typename dst_t>
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
                             const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
                             const sycl::nd_item<3> & item_ct1) {

    const int64_t work_group_size = item_ct1.get_local_range(2);
    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);

    // dims 1-3 are mapped onto the remaining nd_range dimensions;
    // dims 2 and 3 share group dimension 0
    const int64_t i01 = item_ct1.get_group(1);
    const int64_t i02 = item_ct1.get_group(0) % ne02;
    const int64_t i03 = item_ct1.get_group(0) / ne02;

    // make each work-item deal with more elements since sycl global range can not exceed max int
    const src_t * x = static_cast<const src_t *>(vx);
    const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;                // strided source offset
    const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;        // contiguous dest offset

#pragma unroll
    for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
        y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
    }
}
507
// Host launcher for convert_unary_nc. Maps (ne02*ne03, ne01, ceil(ne00/BS))
// onto a 3-D nd_range, shrinking the innermost group count when the total
// global range would exceed the SYCL int limit (the kernel's grid-stride
// loop picks up the slack).
template <typename src_t, typename dst_t>
static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
                                  const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });

    sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));

    // decrease global range when it exceeds the max int
    // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
    int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
    sycl::range<3> workgroup_size(1, 1, downsized_workgroup);

    queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
        convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
    });
}
525
// Contiguous 1-D convert: treat the k elements as a single row with
// unit-row strides (strides are in elements, so k == one full row).
template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
    convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
}
530
531
// Select the dequantize/convert routine that produces fp16 output for the
// given quantization type; returns nullptr for unsupported types.
// For Q4_0, Q4_K and Q6_K the reorder-optimized variant is chosen when the
// source tensor's extra data carries the optimized_feature.reorder flag.
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_sycl;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_sycl;
        case GGML_TYPE_Q4_K:
            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_K_sycl_reorder;
            } else {
                return dequantize_row_q4_K_sycl;
            }
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_sycl;
        case GGML_TYPE_Q6_K:
            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q6_K_sycl_reorder;
            } else {
                return dequantize_row_q6_K_sycl;
            }
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_sycl;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_sycl;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_sycl;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_sycl;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_sycl;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_sycl;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_sycl;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_sycl;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_F32:
            return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
        // bf16 support is only available with the Intel LLVM compiler
        // when the oneAPI bfloat16 extension header is present
        case GGML_TYPE_BF16:
            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
        default:
            return nullptr;
    }
}
597
// Select the dequantize/convert routine that produces fp32 output for the
// given quantization type; returns nullptr for unsupported types.
// Mirrors ggml_get_to_fp16_sycl, including the reorder-layout dispatch for
// Q4_0, Q4_K and Q6_K based on the source tensor's extra data.
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_row_q4_0_sycl;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_sycl;
        case GGML_TYPE_Q5_0:
            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_sycl;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_sycl;
        case GGML_TYPE_Q4_K:
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_K_sycl_reorder;
            } else {
                return dequantize_row_q4_K_sycl;
            }
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_sycl;
        case GGML_TYPE_Q6_K:
            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q6_K_sycl_reorder;
            } else {
                return dequantize_row_q6_K_sycl;
            }
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_sycl;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_sycl;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_sycl;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_sycl;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_sycl;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_sycl;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_sycl;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_sycl;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_sycl;
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_sycl;
        case GGML_TYPE_F16:
            return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
        // bf16 support is only available with the Intel LLVM compiler
        // when the oneAPI bfloat16 extension header is present
        case GGML_TYPE_BF16:
            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
        default:
            return nullptr;
    }
}
664
// Select the non-contiguous-to-fp16 convert routine for the given type;
// only plain float (and bf16 when available) sources are supported,
// all other types return nullptr.
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_nc_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
        case GGML_TYPE_BF16:
            return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
        default:
            return nullptr;
    }
}