1#include "convert.hpp"
  2#include "dequantize.hpp"
  3#include "presets.hpp"
  4
  5#if defined(__INTEL_LLVM_COMPILER)
  6    #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
  7        #include <sycl/ext/oneapi/bfloat16.hpp>
  8        #define GGML_SYCL_HAS_BF16
  9    #endif
 10#endif
 11
 12template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 13static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
 14                             const sycl::nd_item<3> &item_ct1) {
 15    const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
 16                       item_ct1.get_local_id(2));
 17
 18    if (i >= k) {
 19        return;
 20    }
 21
 22    const int64_t ib = i/qk; // block index
 23    const int64_t iqs = (i%qk)/qr; // quant index
 24    const int64_t iybs = i - i%qk; // y block start index
 25    const int64_t y_offset = qr == 1 ? 1 : qk/2;
 26
 27    // dequantize
 28    dfloat2 v;
 29    dequantize_kernel(vx, ib, iqs, v);
 30
 31    y[iybs + iqs + 0] = v.x();
 32    y[iybs + iqs + y_offset] = v.y();
 33}
 34
 35template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 36static void dequantize_block_sycl(const void *__restrict__ vx,
 37                                  dst_t *__restrict__ y, const int64_t k,
 38                                  dpct::queue_ptr stream) {
 39    const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
 40    {
 41        dpct::has_capability_or_fail(stream->get_device(),
 42                                     {sycl::aspect::fp16});
 43        stream->parallel_for(
 44            sycl::nd_range<3>(
 45                sycl::range<3>(1, 1, num_blocks) *
 46                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
 47                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
 48            [=](sycl::nd_item<3> item_ct1) {
 49                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
 50            });
 51    }
 52}
 53
 54template <typename dst_t>
 55static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
 56                                     dpct::queue_ptr stream) {
 57    const int64_t nb = k / QK_K;
 58#if QK_K == 256
 59    {
 60        dpct::has_capability_or_fail(stream->get_device(),
 61                                     {sycl::aspect::fp16});
 62
 63        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
 64                                                   sycl::range<3>(1, 1, 64),
 65                                               sycl::range<3>(1, 1, 64)),
 66                             [=](sycl::nd_item<3> item_ct1) {
 67                                 dequantize_block_q2_K(vx, y, item_ct1);
 68                             });
 69    }
 70#else
 71    {
 72        dpct::has_capability_or_fail(stream->get_device(),
 73                                     {sycl::aspect::fp16});
 74
 75        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
 76                                                   sycl::range<3>(1, 1, 32),
 77                                               sycl::range<3>(1, 1, 32)),
 78                             [=](sycl::nd_item<3> item_ct1) {
 79                                 dequantize_block_q2_K(vx, y, item_ct1);
 80                             });
 81    }
 82
 83#endif
 84}
 85
 86template <typename dst_t>
 87static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
 88                                     dpct::queue_ptr stream) {
 89    const int64_t nb = k / QK_K;
 90#if QK_K == 256
 91    {
 92        dpct::has_capability_or_fail(stream->get_device(),
 93                                     {sycl::aspect::fp16});
 94
 95        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
 96                                                   sycl::range<3>(1, 1, 64),
 97                                               sycl::range<3>(1, 1, 64)),
 98                             [=](sycl::nd_item<3> item_ct1) {
 99                                 dequantize_block_q3_K(vx, y, item_ct1);
100                             });
101    }
102#else
103    {
104        dpct::has_capability_or_fail(stream->get_device(),
105                                     {sycl::aspect::fp16});
106
107        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
108                                                   sycl::range<3>(1, 1, 32),
109                                               sycl::range<3>(1, 1, 32)),
110                             [=](sycl::nd_item<3> item_ct1) {
111                                 dequantize_block_q3_K(vx, y, item_ct1);
112                             });
113    }
114#endif
115}
116
117template <typename dst_t>
118static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
119                                     dpct::queue_ptr stream) {
120    const int64_t nb32 = k / 32;
121    const int64_t nb = (k + 255) / 256;
122    {
123        dpct::has_capability_or_fail(stream->get_device(),
124                                     {sycl::aspect::fp16});
125
126        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
127                                                   sycl::range<3>(1, 1, 32),
128                                               sycl::range<3>(1, 1, 32)),
129                             [=](sycl::nd_item<3> item_ct1) {
130                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);
131                             });
132    }
133}
134
135template <typename dst_t>
136static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
137                                     dpct::queue_ptr stream) {
138
139    dpct::has_capability_or_fail(stream->get_device(),
140                                    {sycl::aspect::fp16});
141
142    int constexpr WARP_K = WARP_SIZE * QK4_0;
143    const int n_warp = (k + WARP_K - 1) / WARP_K;
144    GGML_ASSERT(k % 2 == 0);
145    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
146        sycl::range<3>(1, 1, WARP_SIZE),
147        sycl::range<3>(1, 1, WARP_SIZE)),
148        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
149            dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
150        });
151
152}
153
154template <typename dst_t>
155static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
156                                     dpct::queue_ptr stream) {
157    const int64_t nb32 = k / 32;
158    const int64_t nb = (k + 255) / 256;
159    {
160        dpct::has_capability_or_fail(stream->get_device(),
161                                     {sycl::aspect::fp16});
162
163        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
164                                                   sycl::range<3>(1, 1, 32),
165                                               sycl::range<3>(1, 1, 32)),
166                             [=](sycl::nd_item<3> item_ct1) {
167                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);
168                             });
169    }
170}
171
172
173template <typename dst_t>
174static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
175                                     dpct::queue_ptr stream) {
176    const int64_t nb = k / QK_K;
177    {
178        dpct::has_capability_or_fail(stream->get_device(),
179                                     {sycl::aspect::fp16});
180
181        stream->submit([&](sycl::handler &cgh) {
182            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
183            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
184                                                   sycl::range<3>(1, 1, 32),
185                                               sycl::range<3>(1, 1, 32)),
186                             [=](sycl::nd_item<3> item_ct1) {
187                                 dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
188                             });
189        });
190    }
191}
192
193template <typename dst_t>
194static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
195    const int64_t nb = k / QK_K;
196    const size_t  local_size  = 32;
197    const size_t  global_size = nb * local_size;
198
199    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
200
201    stream->submit([&](sycl::handler & cgh) {
202        sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
203
204        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
205                         [=](sycl::nd_item<1> item_ct1) {
206                             dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
207                         });
208    });
209}
210
211template <typename dst_t>
212static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
213                                     dpct::queue_ptr stream) {
214    const int64_t nb = k / QK_K;
215#if QK_K == 256
216    {
217        dpct::has_capability_or_fail(stream->get_device(),
218                                     {sycl::aspect::fp16});
219
220        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
221                                                   sycl::range<3>(1, 1, 64),
222                                               sycl::range<3>(1, 1, 64)),
223                             [=](sycl::nd_item<3> item_ct1) {
224                                 dequantize_block_q5_K(vx, y, item_ct1);
225                             });
226    }
227#else
228    {
229        dpct::has_capability_or_fail(stream->get_device(),
230                                     {sycl::aspect::fp16});
231
232        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
233                                                   sycl::range<3>(1, 1, 32),
234                                               sycl::range<3>(1, 1, 32)),
235                             [=](sycl::nd_item<3> item_ct1) {
236                                 dequantize_block_q5_K(vx, y, item_ct1);
237                             });
238    }
239
240#endif
241}
242
243template <typename dst_t>
244static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
245                                     dpct::queue_ptr stream) {
246    const int64_t nb = k / QK_K;
247#if QK_K == 256
248    {
249        dpct::has_capability_or_fail(stream->get_device(),
250                                     {sycl::aspect::fp16});
251
252        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
253                                                   sycl::range<3>(1, 1, 64),
254                                               sycl::range<3>(1, 1, 64)),
255                             [=](sycl::nd_item<3> item_ct1) {
256                                 dequantize_block_q6_K(vx, y, item_ct1);
257                             });
258    }
259#else
260    {
261        dpct::has_capability_or_fail(stream->get_device(),
262                                     {sycl::aspect::fp16});
263
264        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
265                                                   sycl::range<3>(1, 1, 32),
266                                               sycl::range<3>(1, 1, 32)),
267                             [=](sycl::nd_item<3> item_ct1) {
268                                 dequantize_block_q6_K(vx, y, item_ct1);
269                             });
270    }
271
272#endif
273}
274
275template <typename dst_t>
276static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
277    const int64_t nb = k / QK_K;
278
279    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
280
281    stream->parallel_for(
282        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
283        [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
284}
285
286template <typename dst_t>
287static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
288                                        dpct::queue_ptr stream) {
289    const int64_t nb = k / QK_K;
290    {
291        dpct::has_capability_or_fail(stream->get_device(),
292                                     {sycl::aspect::fp16});
293
294        stream->submit([&](sycl::handler &cgh) {
295            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
296                                                   sycl::range<3>(1, 1, 32),
297                                               sycl::range<3>(1, 1, 32)),
298                             [=](sycl::nd_item<3> item_ct1) {
299                                 dequantize_block_iq1_s(
300                                     vx, y, item_ct1, iq1s_grid_gpu
301                                     );
302                             });
303        });
304    }
305}
306
307template <typename dst_t>
308static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
309                                        dpct::queue_ptr stream) {
310    const int64_t nb = k / QK_K;
311    {
312        dpct::has_capability_or_fail(stream->get_device(),
313                                     {sycl::aspect::fp16});
314
315        stream->submit([&](sycl::handler &cgh) {
316            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
317                                                   sycl::range<3>(1, 1, 32),
318                                               sycl::range<3>(1, 1, 32)),
319                             [=](sycl::nd_item<3> item_ct1) {
320                                 dequantize_block_iq1_m(
321                                     vx, y, item_ct1, iq1s_grid_gpu
322                                     );
323                             });
324        });
325    }
326}
327
328template <typename dst_t>
329static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
330                                        dpct::queue_ptr stream) {
331    const int64_t nb = k / QK_K;
332    {
333        dpct::has_capability_or_fail(stream->get_device(),
334                                     {sycl::aspect::fp16});
335
336        stream->submit([&](sycl::handler &cgh) {
337            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
338                                                   sycl::range<3>(1, 1, 32),
339                                               sycl::range<3>(1, 1, 32)),
340                             [=](sycl::nd_item<3> item_ct1) {
341                                 dequantize_block_iq2_xxs(
342                                     vx, y, item_ct1, iq2xxs_grid,
343                                     ksigns_iq2xs, kmask_iq2xs);
344                             });
345        });
346    }
347}
348
349template <typename dst_t>
350static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
351                                       dpct::queue_ptr stream) {
352    const int64_t nb = k / QK_K;
353    {
354        dpct::has_capability_or_fail(stream->get_device(),
355                                     {sycl::aspect::fp16});
356
357        stream->submit([&](sycl::handler &cgh) {
358            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
359                                                   sycl::range<3>(1, 1, 32),
360                                               sycl::range<3>(1, 1, 32)),
361                             [=](sycl::nd_item<3> item_ct1) {
362                                 dequantize_block_iq2_xs(
363                                     vx, y, item_ct1, iq2xs_grid,
364                                     ksigns_iq2xs, kmask_iq2xs);
365                             });
366        });
367    }
368}
369
370template <typename dst_t>
371static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
372                                      dpct::queue_ptr stream) {
373    const int64_t nb = k / QK_K;
374    {
375        dpct::has_capability_or_fail(stream->get_device(),
376                                     {sycl::aspect::fp16});
377
378        stream->submit([&](sycl::handler &cgh) {
379            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
380                                                   sycl::range<3>(1, 1, 32),
381                                               sycl::range<3>(1, 1, 32)),
382                             [=](sycl::nd_item<3> item_ct1) {
383                                 dequantize_block_iq2_s(vx, y, item_ct1);
384                             });
385        });
386    }
387}
388
389
390template <typename dst_t>
391static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
392                                        dpct::queue_ptr stream) {
393    const int64_t nb = k / QK_K;
394    {
395        dpct::has_capability_or_fail(stream->get_device(),
396                                     {sycl::aspect::fp16});
397
398        stream->submit([&](sycl::handler &cgh) {
399            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
400                                                   sycl::range<3>(1, 1, 32),
401                                               sycl::range<3>(1, 1, 32)),
402                             [=](sycl::nd_item<3> item_ct1) {
403                                 dequantize_block_iq3_xxs(
404                                     vx, y, item_ct1, iq3xxs_grid,
405                                     ksigns_iq2xs, kmask_iq2xs);
406                             });
407        });
408    }
409}
410
411template <typename dst_t>
412static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
413                                        dpct::queue_ptr stream) {
414    const int64_t nb = k / QK_K;
415    {
416        dpct::has_capability_or_fail(stream->get_device(),
417                                     {sycl::aspect::fp16});
418
419        stream->submit([&](sycl::handler &cgh) {
420            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
421                                                   sycl::range<3>(1, 1, 32),
422                                               sycl::range<3>(1, 1, 32)),
423                             [=](sycl::nd_item<3> item_ct1) {
424                                 dequantize_block_iq3_s(
425                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
426                             });
427        });
428    }
429}
430
431template <typename dst_t>
432static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
433                                       dpct::queue_ptr stream) {
434    const int64_t nb = (k + QK_K - 1) / QK_K;
435#if QK_K == 64
436    dequantize_row_iq4_nl_sycl(vx, y, k, stream);
437#else
438      {
439            dpct::has_capability_or_fail(stream->get_device(),
440                                         {sycl::aspect::fp16});
441
442            stream->submit([&](sycl::handler &cgh) {
443                  cgh.parallel_for(
444                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
445                                            sycl::range<3>(1, 1, 32),
446                                        sycl::range<3>(1, 1, 32)),
447                      [=](sycl::nd_item<3> item_ct1) {
448                            dequantize_block_iq4_xs(vx, y, item_ct1);
449                      });
450            });
451      }
452#endif
453}
454
455template <typename dst_t>
456static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
457                                       dpct::queue_ptr stream) {
458    const int64_t nb = (k + QK_K - 1) / QK_K;
459      {
460            dpct::has_capability_or_fail(stream->get_device(),
461                                         {sycl::aspect::fp16});
462
463            stream->submit([&](sycl::handler &cgh) {
464                  cgh.parallel_for(
465                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
466                                            sycl::range<3>(1, 1, 32),
467                                        sycl::range<3>(1, 1, 32)),
468                      [=](sycl::nd_item<3> item_ct1) {
469                            dequantize_block_iq4_nl(vx, y, item_ct1);
470                      });
471            });
472      }
473}
474
475template <typename dst_t>
476static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
477    const int nb = (k + QK_K - 1) / QK_K;
478    stream->parallel_for(
479        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
480        [=](sycl::nd_item<3> item_ct1) {
481            dequantize_block_mxfp4(vx, y, item_ct1);
482        });
483}
484
485template <typename src_t, typename dst_t>
486static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
487                          const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
488                          const sycl::nd_item<3> & item_ct1) {
489
490    const int64_t work_group_size = item_ct1.get_local_range(2);
491    const int64_t global_id       = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
492
493    const int64_t i01 = item_ct1.get_group(1);
494    const int64_t i02 = item_ct1.get_group(0) % ne02;
495    const int64_t i03 = item_ct1.get_group(0) / ne02;
496
497    // make each work-item deal with more elements since sycl global range can not exceed max int
498    const src_t * x = static_cast<const src_t *>(vx);
499    const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
500    const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
501
502#pragma unroll
503    for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
504        y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
505    }
506}
507
508template <typename src_t, typename dst_t>
509static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
510                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
511                                  const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
512    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
513
514    sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));
515
516    // decrease global range when it exceeds the max int
517    // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
518    int64_t        downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
519    sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
520
521    queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
522        convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
523    });
524}
525
526template <typename src_t, typename dst_t>
527static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
528    convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
529}
530
531
532to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
533    switch (type) {
534        case GGML_TYPE_Q4_0:
535            if (dst->src[0]->extra &&
536                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
537                return dequantize_row_q4_0_sycl_reorder;
538            } else {
539                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
540            }
541        case GGML_TYPE_Q4_1:
542            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
543        case GGML_TYPE_Q5_0:
544            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
545        case GGML_TYPE_Q5_1:
546            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
547        case GGML_TYPE_Q8_0:
548            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
549        case GGML_TYPE_Q2_K:
550            return dequantize_row_q2_K_sycl;
551        case GGML_TYPE_Q3_K:
552            return dequantize_row_q3_K_sycl;
553        case GGML_TYPE_Q4_K:
554            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
555                return dequantize_row_q4_K_sycl_reorder;
556            } else {
557                return dequantize_row_q4_K_sycl;
558            }
559        case GGML_TYPE_Q5_K:
560            return dequantize_row_q5_K_sycl;
561        case GGML_TYPE_Q6_K:
562            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
563                return dequantize_row_q6_K_sycl_reorder;
564            } else {
565                return dequantize_row_q6_K_sycl;
566            }
567        case GGML_TYPE_IQ1_S:
568            return dequantize_row_iq1_s_sycl;
569        case GGML_TYPE_IQ1_M:
570            return dequantize_row_iq1_m_sycl;
571        case GGML_TYPE_IQ2_XXS:
572            return dequantize_row_iq2_xxs_sycl;
573        case GGML_TYPE_IQ2_XS:
574            return dequantize_row_iq2_xs_sycl;
575        case GGML_TYPE_IQ2_S:
576            return dequantize_row_iq2_s_sycl;
577        case GGML_TYPE_IQ3_XXS:
578            return dequantize_row_iq3_xxs_sycl;
579        case GGML_TYPE_IQ3_S:
580            return dequantize_row_iq3_s_sycl;
581        case GGML_TYPE_IQ4_XS:
582            return dequantize_row_iq4_xs_sycl;
583        case GGML_TYPE_IQ4_NL:
584            return dequantize_row_iq4_nl_sycl;
585        case GGML_TYPE_MXFP4:
586            return dequantize_row_mxfp4_sycl;
587        case GGML_TYPE_F32:
588            return convert_unary_sycl<float>;
589#ifdef GGML_SYCL_HAS_BF16
590        case GGML_TYPE_BF16:
591            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
592#endif
593        default:
594            return nullptr;
595    }
596}
597
598to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
599    switch (type) {
600        case GGML_TYPE_Q4_0:
601            if (dst->src[0]->extra &&
602                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
603                return dequantize_row_q4_0_sycl_reorder;
604            } else {
605                return dequantize_row_q4_0_sycl;
606            }
607        case GGML_TYPE_Q4_1:
608            return dequantize_row_q4_1_sycl;
609        case GGML_TYPE_Q5_0:
610            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
611        case GGML_TYPE_Q5_1:
612            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
613        case GGML_TYPE_Q8_0:
614            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
615        case GGML_TYPE_Q2_K:
616            return dequantize_row_q2_K_sycl;
617        case GGML_TYPE_Q3_K:
618            return dequantize_row_q3_K_sycl;
619        case GGML_TYPE_Q4_K:
620            if (dst->src[0]->extra &&
621                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
622                return dequantize_row_q4_K_sycl_reorder;
623            } else {
624                return dequantize_row_q4_K_sycl;
625            }
626        case GGML_TYPE_Q5_K:
627            return dequantize_row_q5_K_sycl;
628        case GGML_TYPE_Q6_K:
629            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
630                return dequantize_row_q6_K_sycl_reorder;
631            } else {
632                return dequantize_row_q6_K_sycl;
633            }
634        case GGML_TYPE_IQ1_S:
635            return dequantize_row_iq1_s_sycl;
636        case GGML_TYPE_IQ1_M:
637            return dequantize_row_iq1_m_sycl;
638        case GGML_TYPE_IQ2_XXS:
639            return dequantize_row_iq2_xxs_sycl;
640        case GGML_TYPE_IQ2_XS:
641            return dequantize_row_iq2_xs_sycl;
642        case GGML_TYPE_IQ2_S:
643            return dequantize_row_iq2_s_sycl;
644        case GGML_TYPE_IQ3_XXS:
645            return dequantize_row_iq3_xxs_sycl;
646        case GGML_TYPE_IQ3_S:
647            return dequantize_row_iq3_s_sycl;
648        case GGML_TYPE_IQ4_XS:
649            return dequantize_row_iq4_xs_sycl;
650        case GGML_TYPE_IQ4_NL:
651            return dequantize_row_iq4_nl_sycl;
652        case GGML_TYPE_MXFP4:
653            return dequantize_row_mxfp4_sycl;
654        case GGML_TYPE_F16:
655            return convert_unary_sycl<sycl::half>;
656#ifdef GGML_SYCL_HAS_BF16
657        case GGML_TYPE_BF16:
658            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
659#endif
660        default:
661            return nullptr;
662    }
663}
664
665to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
666    switch (type) {
667        case GGML_TYPE_F32:
668            return convert_unary_nc_sycl<float>;
669#ifdef GGML_SYCL_HAS_BF16
670        case GGML_TYPE_BF16:
671            return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
672#endif
673        default:
674            return nullptr;
675    }
676}