#include "mmvq.hpp"

#include "ggml.h"
#include "common.hpp"
#include "quants.hpp"
#include "vecdotq.hpp"

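// Kernel for the reordered block layout: each sub-group computes one result
// row, with the row's quant blocks strided across the sub-group. Offsets into
// the reordered x buffer come from block_type::get_block_offset() and
// get_d_offset(); the quantized y buffer stores all q8_1 quants first,
// followed by the per-block d/s sycl::half2 pairs.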
template <typename reorder_vec_dot_q_sycl>
static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
                                  const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
    using block_type   = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
    using block_traits = typename block_type::traits;

    const auto sg           = nd_item.get_sub_group();
    const int  sg_range     = sg.get_group_linear_range();
    const int  workgroup_id = nd_item.get_group_linear_id();
    const int  sg_id        = sg.get_group_linear_id();
    const int  row          = workgroup_id * sg_range + sg_id;

    if (row >= nrows) {
        return;
    }

    const int     blocks_per_row              = ncols / block_traits::qk;
    constexpr int blocks_per_subgroup         = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
    constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
    const int     nblocks                     = nrows * (ncols / block_traits::qk);

    static_assert(blocks_per_subgroup > 0);
    static_assert(block_elements_per_subgroup > 0);

    float partial_sum = 0.0f;
    for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
        const auto d_offset  = block_type::get_d_offset(nrows, ncols, ibx);
        // y block index that aligns with ibx
        const int           iby            = i * block_type::block_to_q8_1_ratio();
        const int8_t *      q8_1_quant_ptr = (const int8_t *) vy + iby * QK8_1;
        const sycl::half2 * q8_1_ds_ptr    = (const sycl::half2 *) ((const char *) vy + ncols + iby * sizeof(sycl::half2));

#pragma unroll
        for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
            // x block quant index when casting the quants to int
            const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);

            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
        }
    }

    auto sum = sycl::reduce_over_group(sg, partial_sum, std::plus<>());

    if (sg.leader()) {
        dst[row] = sum;
    }
}

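// Generic MMVQ kernel for the standard block layout: WARP_SIZE work-items
// cooperate on one result row, accumulating vec_dot_q_sycl products over the
// row's quant blocks, then reducing the partial sums with a sub-group XOR
// shuffle.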
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
                          const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int     blocks_per_row  = ncols / qk;
    constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;  // ceiling division ensures blocks_per_warp > 0

    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t *  x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);          // y block index that aligns with ibx

        for (int elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
            // x block quant index when casting the quants to int
            const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));

            tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
        }
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

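// The mul_mat_vec_q_iq*_q8_1 kernels below follow the same per-row scheme as
// mul_mat_vec_q, but call type-specific vec_dot helpers that take extra
// lookup tables (grids and sign masks) where the format requires them.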
template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
                                       const void *__restrict__ vy,
                                       float *__restrict__ dst, const int ncols,
                                       const int nrows,
                                       const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
                                      const void *__restrict__ vy,
                                      float *__restrict__ dst, const int ncols,
                                      const int nrows,
                                      const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
                                     const void *__restrict__ vy,
                                     float *__restrict__ dst, const int ncols,
                                     const int nrows,
                                     const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
                                       const void *__restrict__ vy,
                                       float *__restrict__ dst, const int ncols,
                                       const int nrows,
                                       const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
                                     const void *__restrict__ vy,
                                     float *__restrict__ dst, const int ncols,
                                     const int nrows,
                                     const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
                                     const void *__restrict__ vy,
                                     float *__restrict__ dst, const int ncols,
                                     const int nrows,
                                     const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
                                     const void *__restrict__ vy,
                                     float *__restrict__ dst, const int ncols,
                                     const int nrows,
                                     const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
                                      const void *__restrict__ vy,
                                      float *__restrict__ dst, const int ncols,
                                      const int nrows,
                                      const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
                                      const void *__restrict__ vy,
                                      float *__restrict__ dst, const int ncols,
                                      const int nrows,
                                      const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);

    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    assert(blocks_per_warp > 0);

    // partial sum for each thread
    float tmp = 0.0f;

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
         i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index

        const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

        // x block quant index when casting the quants to int
        const int iqs = vdr * (item_ct1.get_local_id(2) % (qi / vdr));

        tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}

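// Host-side launchers. The reorder_mul_mat_vec_*_sycl variants enqueue the
// reorder kernel with num_subgroups sub-groups per workgroup, one result row
// per sub-group; the plain mul_mat_vec_*_sycl launchers cover GGML_SYCL_MMV_Y
// rows per workgroup with WARP_SIZE work-items per row.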
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                               const int nrows, dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK4_0 == 0);
    const int        block_num_y   = ceil_div(nrows, GGML_SYCL_MMV_Y);
    constexpr size_t num_subgroups = 16;
    GGML_ASSERT(block_num_y % num_subgroups == 0);

    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

    stream->submit([&](sycl::handler & cgh) {
        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
                                                                                           nd_item);
                         });
    });
}

static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK4_0 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);

    {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                 mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
                                     vx, vy, dst, ncols, nrows, item_ct1);
                             });
        });
    }
}

static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK4_1 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        // QK4_1 (== QK4_0 == 32) for consistency with the assert above
                        mul_mat_vec_q<QK4_1, QI4_1, block_q4_1,
                                      VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_MXFP4 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);

    {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                 mul_mat_vec_q<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
                                     vx, vy, dst, ncols, nrows, item_ct1);
                             });
        });
    }
}

static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK5_0 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
                                      VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK5_1 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
                                      VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK8_0 == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
                                      VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
                                      VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
                                      VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
                                      VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

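// Reorder variant for Q4_K; same launch geometry as the Q4_0 reorder path above.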
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                               const int nrows, dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);

    const int        block_num_y   = ceil_div(nrows, GGML_SYCL_MMV_Y);
    constexpr size_t num_subgroups = 16;
    GGML_ASSERT(block_num_y % num_subgroups == 0);

    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

    stream->submit([&](sycl::handler & cgh) {
        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
                                                                                           nd_item);
                         });
    });
}

static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
                                      VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                               const int nrows, dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int        block_num_y   = ceil_div(nrows, GGML_SYCL_MMV_Y);
    constexpr size_t num_subgroups = 16;
    GGML_ASSERT(block_num_y % num_subgroups == 0);

    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

    stream->submit([&](sycl::handler & cgh) {
        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
                                                                                           nd_item);
                         });
    });
}

static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
                                       float *dst, const int ncols,
                                       const int nrows,
                                       dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
                                      VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

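// Note: several of the iq launchers below pass a reduced qi (e.g. QI2_XXS / 2,
// QI4_XS / 4) as the template parameter, which raises blocks_per_warp for
// those formats so each loop iteration covers more blocks per row.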
static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
                                          float *dst, const int ncols,
                                          const int nrows,
                                          dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols,
                                         const int nrows,
                                         dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
                                          float *dst, const int ncols,
                                          const int nrows,
                                          dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
                                        dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols,
                                         const int nrows,
                                         dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK4_NL == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols,
                                         const int nrows,
                                         dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
    {
        stream->submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(block_nums * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(
                            vx, vy, dst, ncols, nrows, item_ct1);
                    });
        });
    }
}

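// Entry point. src1 is expected to already be quantized to q8_1 (src1_ddq_i);
// one kernel launch is dispatched per src1 column, and the reorder path is
// taken when the src0 tensor's extra flags mark its data as reordered.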
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
                                ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
                                const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
                                const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
                                const dpct::queue_ptr & stream) {
    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne00     = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    int id;
    SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
    const size_t q8_1_ts = sizeof(block_q8_1);
    const size_t q8_1_bs = QK8_1;
    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into

    for (int i = 0; i < src1_ncols; i++) {
        const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
        const char * src1_ddq_i_bs     = src1_ddq_i + src1_ddq_i_offset;
        float *      dst_dd_i_bs       = dst_dd_i + i * dst->ne[0];
        switch (src0->type) {
            case GGML_TYPE_Q4_0:
                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                    GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
                    reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                } else {
                    GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
                    mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                }
                break;
            case GGML_TYPE_Q4_1:
                mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q5_0:
                mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q5_1:
                mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q8_0:
                mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q2_K:
                mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q3_K:
                mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q4_K:
                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                    GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
                    reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                } else {
                    GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n");
                    mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                }
                break;
            case GGML_TYPE_Q5_K:
                mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_Q6_K:
                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                    GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
                    reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                } else {
                    GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl\n");
                    mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                }
                break;
            case GGML_TYPE_IQ1_S:
                mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ1_M:
                mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ2_XXS:
                mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ2_XS:
                mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ2_S:
                mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ3_XXS:
                mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ3_S:
                mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ4_NL:
                mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_IQ4_XS:
                mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            case GGML_TYPE_MXFP4:
                mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                break;
            default:
                GGML_ABORT("fatal error");
        }
    }
    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    GGML_UNUSED(ctx);
}