1//
   2// MIT license
   3// Copyright (C) 2024 Intel Corporation
   4// SPDX-License-Identifier: MIT
   5//
   6
   7//
   8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
   9// See https://llvm.org/LICENSE.txt for license information.
  10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11//
  12
  13#include "mmq.hpp"
  14#include "vecdotq.hpp"
  15
  16typedef void (*allocate_tiles_sycl_t)(
  17    int** x_ql,
  18    sycl::half2** x_dm,
  19    int** x_qh,
  20    int** x_sc);
  21typedef void (*load_tiles_sycl_t)(
  22    const void* __restrict__ vx,
  23    int* __restrict__ x_ql,
  24    sycl::half2* __restrict__ x_dm,
  25    int* __restrict__ x_qh,
  26    int* __restrict__ x_sc,
  27    const int& i_offset,
  28    const int& i_max,
  29    const int& k,
  30    const int& blocks_per_row);
  31typedef float (*vec_dot_q_mul_mat_sycl_t)(
  32    const int* __restrict__ x_ql,
  33    const sycl::half2* __restrict__ x_dm,
  34    const int* __restrict__ x_qh,
  35    const int* __restrict__ x_sc,
  36    const int* __restrict__ y_qs,
  37    const sycl::half2* __restrict__ y_ms,
  38    const int& i,
  39    const int& j,
  40    const int& k);
  41
  42
  43template <int mmq_y>
  44static __dpct_inline__ void
  45allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
  46                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
  47    (void)x_qh; (void)x_sc;
  48
  49    *x_ql = tile_x_qs_q4_0;
  50    *x_dm = (sycl::half2 *)tile_x_d_q4_0;
  51}
  52
  53template <int mmq_y, int nwarps, bool need_check>
  54static __dpct_inline__ void
  55load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
  56                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
  57                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
  58                const int &k, const int &blocks_per_row) {
  59    (void)x_qh; (void)x_sc;
  60    GGML_SYCL_ASSUME(i_offset >= 0);
  61    GGML_SYCL_ASSUME(i_offset <  nwarps);
  62    GGML_SYCL_ASSUME(k >= 0);
  63    GGML_SYCL_ASSUME(k <  WARP_SIZE);
  64
  65    const int kbx  = k / QI4_0;
  66    const int kqsx = k % QI4_0;
  67
  68    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
  69
  70    float * x_dmf = (float *) x_dm;
  71
  72#pragma unroll
  73    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  74        int i = i0 + i_offset;
  75
  76        if (need_check) {
  77            i = sycl::min(i, i_max);
  78        }
  79
  80        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
  81
  82        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
  83        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
  84    }
  85
  86    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
  87    const int kbxd = k % blocks_per_tile_x_row;
  88
  89#pragma unroll
  90    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
  91        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
  92
  93        if (need_check) {
  94            i = sycl::min(i, i_max);
  95        }
  96
  97        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
  98
  99        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
 100    }
 101}
 102
 103static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
 104    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 105    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 106    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 107    const int &i, const int &j, const int &k) {
 108    (void)x_qh; (void)x_sc;
 109
 110    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 111    const float * x_dmf = (const float *) x_dm;
 112
 113    int u[2*VDR_Q4_0_Q8_1_MMQ];
 114
 115#pragma unroll
 116    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
 117        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
 118        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
 119    }
 120
 121    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
 122        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
 123         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 124}
 125
 126template <int mmq_y>
 127static __dpct_inline__ void
 128allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 129                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
 130    (void)x_qh; (void)x_sc;
 131
 132    *x_ql = tile_x_qs_q4_1;
 133    *x_dm = tile_x_dm_q4_1;
 134}
 135
 136
 137template <int mmq_y, int nwarps, bool need_check>
 138static __dpct_inline__ void
 139load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
 140                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 141                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 142                const int &k, const int &blocks_per_row) {
 143    (void)x_qh; (void)x_sc;
 144
 145    GGML_SYCL_ASSUME(i_offset >= 0);
 146    GGML_SYCL_ASSUME(i_offset <  nwarps);
 147    GGML_SYCL_ASSUME(k >= 0);
 148    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 149
 150    const int kbx  = k / QI4_1;
 151    const int kqsx = k % QI4_1;
 152
 153    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
 154
 155#pragma unroll
 156    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 157        int i = i0 + i_offset;
 158
 159        if (need_check) {
 160            i = sycl::min(i, i_max);
 161        }
 162
 163        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
 164
 165        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
 166    }
 167
 168    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
 169    const int kbxd = k % blocks_per_tile_x_row;
 170
 171#pragma unroll
 172    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
 173        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
 174
 175        if (need_check) {
 176            i = sycl::min(i, i_max);
 177        }
 178
 179        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
 180
 181        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
 182    }
 183}
 184
 185static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
 186    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 187    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 188    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 189    const int &i, const int &j, const int &k) {
 190    (void)x_qh; (void)x_sc;
 191
 192    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 193
 194    int u[2*VDR_Q4_1_Q8_1_MMQ];
 195
 196#pragma unroll
 197    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
 198        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
 199        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
 200    }
 201
 202    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
 203        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
 204         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 205}
 206
 207template <int mmq_y>
 208static __dpct_inline__ void
 209allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 210                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
 211    (void)x_qh; (void)x_sc;
 212
 213    *x_ql = tile_x_ql_q5_0;
 214    *x_dm = (sycl::half2 *)tile_x_d_q5_0;
 215}
 216
 217template <int mmq_y, int nwarps, bool need_check>
 218static __dpct_inline__ void
 219load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
 220                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 221                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 222                const int &k, const int &blocks_per_row) {
 223    (void)x_qh; (void)x_sc;
 224
 225    GGML_SYCL_ASSUME(i_offset >= 0);
 226    GGML_SYCL_ASSUME(i_offset <  nwarps);
 227    GGML_SYCL_ASSUME(k >= 0);
 228    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 229
 230    const int kbx  = k / QI5_0;
 231    const int kqsx = k % QI5_0;
 232
 233    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
 234
 235#pragma unroll
 236    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 237        int i = i0 + i_offset;
 238
 239        if (need_check) {
 240            i = sycl::min(i, i_max);
 241        }
 242
 243        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
 244
 245        const int ql = get_int_from_uint8(bxi->qs, kqsx);
 246        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
 247
 248        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
 249        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
 250        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
 251        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
 252        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
 253        qs0 = dpct::vectorized_binary<sycl::char4>(
 254            qs0, 0x10101010, dpct::sub_sat()); // subtract 16
 255
 256        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
 257
 258        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
 259        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
 260        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
 261        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
 262        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
 263        qs1 = dpct::vectorized_binary<sycl::char4>(
 264            qs1, 0x10101010, dpct::sub_sat()); // subtract 16
 265
 266        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
 267    }
 268
 269    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
 270    const int kbxd = k % blocks_per_tile_x_row;
 271    float * x_dmf = (float *) x_dm;
 272
 273#pragma unroll
 274    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
 275        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
 276
 277        if (need_check) {
 278            i = sycl::min(i, i_max);
 279        }
 280
 281        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 282
 283        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
 284    }
 285}
 286
 287static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
 288    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 289    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 290    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 291    const int &i, const int &j, const int &k) {
 292    (void)x_qh; (void)x_sc;
 293
 294    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 295    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
 296    const float * x_dmf = (const float *) x_dm;
 297    const float * y_df  = (const float *) y_ds;
 298
 299    int u[2*VDR_Q5_0_Q8_1_MMQ];
 300
 301#pragma unroll
 302    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
 303        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
 304        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
 305    }
 306
 307    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
 308        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 309}
 310
 311template <int mmq_y>
 312static __dpct_inline__ void
 313allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 314                    int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
 315    (void)x_qh; (void)x_sc;
 316
 317    *x_ql = tile_x_ql_q5_1;
 318    *x_dm = tile_x_dm_q5_1;
 319}
 320
 321template <int mmq_y, int nwarps, bool need_check>
 322static __dpct_inline__ void
 323load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
 324                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 325                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 326                const int &k, const int &blocks_per_row) {
 327    (void)x_qh; (void)x_sc;
 328
 329    GGML_SYCL_ASSUME(i_offset >= 0);
 330    GGML_SYCL_ASSUME(i_offset < nwarps);
 331    GGML_SYCL_ASSUME(k >= 0);
 332    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 333
 334    const int kbx  = k / QI5_1;
 335    const int kqsx = k % QI5_1;
 336
 337    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
 338
 339#pragma unroll
 340    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 341        int i = i0 + i_offset;
 342
 343        if (need_check) {
 344            i = sycl::min(i, i_max);
 345        }
 346
 347        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
 348
 349        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
 350        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
 351
 352        int qs0 = (ql >>  0) & 0x0F0F0F0F;
 353        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
 354        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
 355        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
 356        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
 357
 358        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
 359
 360        int qs1 = (ql >>  4) & 0x0F0F0F0F;
 361        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
 362        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
 363        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
 364        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 365
 366        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
 367    }
 368
 369    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
 370    const int kbxd = k % blocks_per_tile_x_row;
 371
 372#pragma unroll
 373    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
 374        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
 375
 376        if (need_check) {
 377            i = sycl::min(i, i_max);
 378        }
 379
 380        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
 381
 382        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
 383    }
 384}
 385
 386static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
 387    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 388    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 389    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 390    const int &i, const int &j, const int &k) {
 391    (void)x_qh; (void)x_sc;
 392
 393    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 394    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
 395
 396    int u[2*VDR_Q5_1_Q8_1_MMQ];
 397
 398#pragma unroll
 399    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
 400        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
 401        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
 402    }
 403
 404    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
 405        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 406}
 407
 408template <int mmq_y>
 409static __dpct_inline__ void
 410allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 411                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
 412    (void)x_qh; (void)x_sc;
 413
 414    *x_ql = tile_x_qs_q8_0;
 415    *x_dm = (sycl::half2 *)tile_x_d_q8_0;
 416}
 417
 418template <int mmq_y, int nwarps, bool need_check>
 419static __dpct_inline__ void
 420load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
 421                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 422                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 423                const int &k, const int &blocks_per_row) {
 424    (void)x_qh; (void)x_sc;
 425
 426    GGML_SYCL_ASSUME(i_offset >= 0);
 427    GGML_SYCL_ASSUME(i_offset <  nwarps);
 428    GGML_SYCL_ASSUME(k >= 0);
 429    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 430
 431    const int kbx  = k / QI8_0;
 432    const int kqsx = k % QI8_0;
 433    float * x_dmf = (float *) x_dm;
 434
 435    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
 436
 437#pragma unroll
 438    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 439        int i = i0 + i_offset;
 440
 441        if (need_check) {
 442            i = sycl::min(i, i_max);
 443        }
 444
 445        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 446
 447        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
 448    }
 449
 450    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
 451    const int kbxd = k % blocks_per_tile_x_row;
 452
 453#pragma unroll
 454    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
 455        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
 456
 457        if (need_check) {
 458            i = sycl::min(i, i_max);
 459        }
 460
 461        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 462
 463        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
 464    }
 465}
 466
 467static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
 468    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 469    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 470    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 471    const int &i, const int &j, const int &k) {
 472    (void)x_qh; (void)x_sc;
 473
 474    const float * x_dmf = (const float *) x_dm;
 475    const float * y_df  = (const float *) y_ds;
 476
 477    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
 478        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
 479         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
 480}
 481
 482template <int mmq_y>
 483static __dpct_inline__ void
 484allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 485                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
 486                    int *tile_x_sc_q2_K) {
 487    (void)x_qh;
 488
 489    *x_ql = tile_x_ql_q2_K;
 490    *x_dm = tile_x_dm_q2_K;
 491    *x_sc = tile_x_sc_q2_K;
 492}
 493
 494template <int mmq_y, int nwarps, bool need_check>
 495static __dpct_inline__ void
 496load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
 497                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 498                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 499                const int &k, const int &blocks_per_row) {
 500    (void)x_qh;
 501
 502    GGML_SYCL_ASSUME(i_offset >= 0);
 503    GGML_SYCL_ASSUME(i_offset <  nwarps);
 504    GGML_SYCL_ASSUME(k >= 0);
 505    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 506
 507    const int kbx  = k / QI2_K;
 508    const int kqsx = k % QI2_K;
 509
 510    const block_q2_K * bx0 = (const block_q2_K *) vx;
 511
 512#pragma unroll
 513    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 514        int i = i0 + i_offset;
 515
 516        if (need_check) {
 517            i = sycl::min(i, i_max);
 518        }
 519
 520        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
 521
 522        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
 523    }
 524
 525    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
 526    const int kbxd = k % blocks_per_tile_x_row;
 527
 528#pragma unroll
 529    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
 530        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
 531
 532        if (need_check) {
 533            i = sycl::min(i, i_max);
 534        }
 535
 536        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
 537
 538        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
 539    }
 540
 541#pragma unroll
 542    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
 543        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
 544
 545        if (need_check) {
 546            i = sycl::min(i, i_max);
 547        }
 548
 549        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
 550
 551        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
 552    }
 553}
 554
 555#define VDR_Q2_K_Q8_1_MMQ  2
 556// contiguous u/y values
 557static __dpct_inline__ float
 558vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
 559                           const uint8_t *__restrict__ scales,
 560                           const sycl::half2 &dm2, const float &d8) {
 561
 562    int sumi_d = 0;
 563    int sumi_m = 0;
 564
 565#pragma unroll
 566    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
 567        int sumi_d_sc = 0;
 568
 569        const int sc = scales[i0 / (QI8_1/2)];
 570
 571        // fill int with 4x m
 572        int m = sc >> 4;
 573        m |= m <<  8;
 574        m |= m << 16;
 575
 576#pragma unroll
 577        for (int i = i0; i < i0 + QI8_1/2; ++i) {
 578            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
 579            sumi_m = dpct::dp4a(m, u[i],
 580                                sumi_m); // multiply sum of q8_1 values with m
 581        }
 582
 583        sumi_d += sumi_d_sc * (sc & 0xF);
 584    }
 585
 586    const sycl::float2 dm2f =
 587        dm2.convert<float, sycl::rounding_mode::automatic>();
 588
 589    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
 590}
 591
 592static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
 593    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 594    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 595    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 596    const int &i, const int &j, const int &k) {
 597    (void)x_qh;
 598
 599    const int kbx = k / QI2_K;
 600    const int ky  = (k % QI2_K) * QR2_K;
 601    const float * y_df = (const float *) y_ds;
 602
 603    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
 604
 605    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
 606    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
 607
 608#pragma unroll
 609    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
 610        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
 611    }
 612
 613    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 614
 615    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
 616    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
 617}
 618
 619template <int mmq_y>
 620static __dpct_inline__ void
 621allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 622                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
 623                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
 624
 625    *x_ql = tile_x_ql_q3_K;
 626    *x_dm = tile_x_dm_q3_K;
 627    *x_qh = tile_x_qh_q3_K;
 628    *x_sc = tile_x_sc_q3_K;
 629}
 630
 631template <int mmq_y, int nwarps, bool need_check>
 632static __dpct_inline__ void
 633load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
 634                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
 635                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
 636                const int &k, const int &blocks_per_row) {
 637
 638    GGML_SYCL_ASSUME(i_offset >= 0);
 639    GGML_SYCL_ASSUME(i_offset <  nwarps);
 640    GGML_SYCL_ASSUME(k >= 0);
 641    GGML_SYCL_ASSUME(k <  WARP_SIZE);
 642
 643    const int kbx  = k / QI3_K;
 644    const int kqsx = k % QI3_K;
 645
 646    const block_q3_K * bx0 = (const block_q3_K *) vx;
 647
 648#pragma unroll
 649    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 650        int i = i0 + i_offset;
 651
 652        if (need_check) {
 653            i = sycl::min(i, i_max);
 654        }
 655
 656        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
 657
 658        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
 659    }
 660
 661    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
 662    const int kbxd = k % blocks_per_tile_x_row;
 663    float * x_dmf = (float *) x_dm;
 664
 665#pragma unroll
 666    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
 667        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
 668
 669        if (need_check) {
 670            i = sycl::min(i, i_max);
 671        }
 672
 673        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
 674
 675        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
 676    }
 677
 678#pragma unroll
 679    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
 680        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
 681
 682        if (need_check) {
 683            i = sycl::min(i, i_max);
 684        }
 685
 686        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
 687
 688        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
 689        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
 690    }
 691
 692#pragma unroll
 693    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
 694        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
 695
 696        if (need_check) {
 697            i = sycl::min(i, i_max);
 698        }
 699
 700        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
 701
 702        const int ksc = k % (QI3_K/4);
 703
 704        const int ksc_low = ksc % (QI3_K/8);
 705        const int shift_low = 4 * (ksc / (QI3_K/8));
 706        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
 707
 708        const int ksc_high = QI3_K/8;
 709        const int shift_high = 2 * ksc;
 710        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
 711
 712        const int sc = dpct::vectorized_binary<sycl::char4>(
 713            sc_low | sc_high, 0x20202020, dpct::sub_sat());
 714
 715        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
 716    }
 717}
 718
 719#define VDR_Q3_K_Q8_1_MMQ  2
 720// contiguous u/y values
 721static __dpct_inline__ float
 722vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
 723                           const int8_t *__restrict__ scales, const float &d3,
 724                           const float &d8) {
 725
 726    int sumi = 0;
 727
 728#pragma unroll
 729    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
 730        int sumi_sc = 0;
 731
 732        for (int i = i0; i < i0 + QI8_1/2; ++i) {
 733            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
 734        }
 735
 736        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
 737    }
 738
 739    return d3*d8 * sumi;
 740}
 741
 742static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
 743    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 744    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 745    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 746    const int &i, const int &j, const int &k) {
 747
 748    const int kbx  = k / QI3_K;
 749    const int ky  = (k % QI3_K) * QR3_K;
 750    const float * x_dmf = (const float *) x_dm;
 751    const float * y_df  = (const float *) y_ds;
 752
 753    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 754
 755    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 756
 757#pragma unroll
 758    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
 759        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
 760        const int shift = 2 * ((ky % 32) / 8);
 761        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
 762
 763        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
 764        const int vlh = (vh << 2) & 0x04040404;
 765
 766        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
 767    }
 768
 769    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
 770    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
 771}
 772
 773template <int mmq_y>
 774static __dpct_inline__ void
 775allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 776                    int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
 777                    int *tile_x_sc_q4_K) {
 778    (void)x_qh;
 779
 780    *x_ql = tile_x_ql_q4_K;
 781    *x_dm = tile_x_dm_q4_K;
 782    *x_sc = tile_x_sc_q4_K;
 783}
 784
// Loads one q4_K tile of src0 into the x_ql / x_dm / x_sc tiles.
//
// Work-item (i_offset, k), with i_offset in [0, nwarps) and k in
// [0, WARP_SIZE), performs three passes:
//   1. quants -> x_ql, one int per row (row stride WARP_SIZE + 1)
//   2. dm     -> x_dm, one half2 per block (row index wrapped by % mmq_y)
//   3. scales -> x_sc, the packed 6-bit scales/mins re-arranged into one
//                byte each (row stride WARP_SIZE / 8)
// When need_check is true, row indices are clamped to i_max (presumably the
// last valid src0 row). x_qh is unused for q4_K.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset <  nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k <  WARP_SIZE);

    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256

    const block_q4_K * bx0 = (const block_q4_K *) vx;

    // Pass 1: quants, one int per (row, k).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    // Guard against QI4_K > WARP_SIZE, where WARP_SIZE / QI4_K would be 0 and
    // the `k / blocks_per_tile_x_row` below would divide by zero.
    constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256

    // Pass 2: per-block dm. The row index is wrapped into [0, mmq_y) with
    // % mmq_y, since i_offset * QI4_K can exceed the tile height.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        // NOTE(review): this branch indexes bxi->dm as a two-element array,
        // while the QK_K == 256 branch copies it as a single value; confirm it
        // matches block_q4_K's layout for QK_K != 256.
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }

    // Pass 3: scales. Each work-item produces one int (4 output bytes) of the
    // re-arranged scale/min bytes for its row, selected by ksc.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

        // Read the packed scale bytes 4 at a time.
        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
 857
 858
 859#define VDR_Q4_K_Q8_1_MMQ  8
 860
 861// contiguous u/y values
 862static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
 863    const int *__restrict__ v, const int *__restrict__ u,
 864    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
 865    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
 866
 867    float sumf_d = 0.0f;
 868    float sumf_m = 0.0f;
 869
 870#pragma unroll
 871    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
 872        int sumi_d = 0;
 873
 874#pragma unroll
 875        for (int j = 0; j < QI8_1; ++j) {
 876            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
 877                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
 878        }
 879
 880        const sycl::float2 ds8f =
 881            ds8[i].convert<float, sycl::rounding_mode::automatic>();
 882
 883        sumf_d += ds8f.x() * (sc[i] * sumi_d);
 884        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
 885    }
 886
 887    const sycl::float2 dm4f =
 888        dm4.convert<float, sycl::rounding_mode::automatic>();
 889
 890    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
 891}
 892
 893
 894static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
 895    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
 896    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
 897    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
 898    const int &i, const int &j, const int &k) {
 899    (void)x_qh;
 900
 901    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 902
 903    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
 904    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
 905                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 906}
 907
 908template <int mmq_y>
 909static __dpct_inline__ void
 910allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
 911                    int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
 912                    int *tile_x_sc_q5_K) {
 913    (void)x_qh;
 914
 915    *x_ql = tile_x_ql_q5_K;
 916    *x_dm = tile_x_dm_q5_K;
 917    *x_sc = tile_x_sc_q5_K;
 918}
 919
// Stage an mmq_y-row slice of q5_K blocks from `vx` into the x tiles.
//   x_ql : row stride 2*WARP_SIZE+1 ints. Each byte holds an unsigned 5-bit
//          quant: the low 4 bits come from qs, bit 4 comes from qh.
//   x_dm : each block's `dm` half2 (only written when QK_K == 256).
//   x_sc : WARP_SIZE/8 ints per row (plus an i/8 offset) of repacked
//          6-bit scales and mins.
// i_offset is this work-item's starting row (< nwarps); k is its lane
// (< WARP_SIZE). When need_check is set, row indices are clamped to i_max so
// out-of-range rows re-read the last valid row instead of running past vx.
// NOTE(review): if WARP_SIZE < QI5_K, WARP_SIZE/QI5_K is 0 and the x_dm index
// degenerates to i/QI5_K + kbxd; verify this layout for such builds.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset <  nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k <  WARP_SIZE);

    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256

    const block_q5_K * bx0 = (const block_q5_K *) vx;

    // Pass 1: quants. Work-items stride over rows by nwarps; each one unpacks
    // one int of qs into two ints of 5-bit values (low/high nibble).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;

        // Low and high nibbles of four packed qs bytes.
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // Select the matching qh bit for each nibble and move it to bit 4 of each byte.
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // Destination columns within the row for the low- and high-nibble values.
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);

        // Row stride is 2*WARP_SIZE+1; the +1 presumably staggers rows across
        // local-memory banks (TODO confirm).
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256

    // Pass 2: per-block dm factors. Rows are spread across
    // nwarps*QI5_K work-item slots per iteration.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

        // Only the QK_K == 256 layout stores dm; otherwise bxi is unused here.
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }

    // Pass 3: scales/mins. Eight rows per nwarps group, WARP_SIZE/8 ints per row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
1003
1004#define VDR_Q5_K_Q8_1_MMQ  8
1005
1006// contiguous u/y values
1007static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
1008    const int *__restrict__ v, const int *__restrict__ u,
1009    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
1010    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
1011
1012    float sumf_d = 0.0f;
1013    float sumf_m = 0.0f;
1014
1015#pragma unroll
1016    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
1017        int sumi_d = 0;
1018
1019#pragma unroll
1020        for (int j = 0; j < QI8_1; ++j) {
1021            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
1022                                sumi_d); // SIMD dot product
1023        }
1024
1025        const sycl::float2 ds8f =
1026            ds8[i].convert<float, sycl::rounding_mode::automatic>();
1027
1028        sumf_d += ds8f.x() * (sc[i] * sumi_d);
1029        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
1030    }
1031
1032    const sycl::float2 dm4f =
1033        dm4.convert<float, sycl::rounding_mode::automatic>();
1034
1035    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
1036}
1037
1038static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
1039    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
1040    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
1041    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
1042    const int &i, const int &j, const int &k) {
1043    (void)x_qh;
1044
1045    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
1046
1047    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
1048    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
1049    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
1050                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
1051}
1052
1053template <int mmq_y>
1054static __dpct_inline__ void
1055allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
1056                    int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
1057    (void)x_qh;
1058
1059    *x_ql = tile_x_ql;
1060    *x_dm = tile_x_dm;
1061    *x_sc = tile_x_sc;
1062}
1063
// Stage an mmq_y-row slice of q6_K blocks from `vx` into the x tiles.
//   x_ql : row stride 2*WARP_SIZE+1 ints. Each byte holds a signed value in
//          [-32, 31]: the 6-bit code is 4 low bits from ql plus 2 high bits
//          from qh, minus 32.
//   x_dm : reinterpreted as float; one block scale `d` per block.
//   x_sc : WARP_SIZE/8 ints per row (plus an i/8 offset); each int holds four
//          int8 block scales.
// i_offset is this work-item's starting row (< nwarps); k is its lane
// (< WARP_SIZE). When need_check is set, row indices are clamped to i_max.
// NOTE(review): if WARP_SIZE < QI6_K, WARP_SIZE/QI6_K is 0 and the x_dmf index
// degenerates to i/QI6_K + kbxd; verify this layout for such builds.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset <  nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k <  WARP_SIZE);

    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256

    const block_q6_K * bx0 = (const block_q6_K *) vx;

    // Pass 1: quants. Work-items stride over rows by nwarps; each one unpacks
    // one int of ql into two ints of 6-bit values (low/high nibble).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;

        // Low and high nibbles of four packed ql bytes.
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // Pick the matching 2-bit qh pairs and place them at bits 4-5 of each byte.
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;

        // Destination columns within the row for the low- and high-nibble values.
        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);

        // Per-byte signed saturating subtract of 32 (0x20): maps the unsigned
        // 6-bit codes [0, 63] to [-32, 31].
        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
                                                 dpct::sub_sat());
        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
                                                 dpct::sub_sat());
    }

    constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
    // q6_K has a single scale per block, so the half2 dm tile is used as float storage.
    float * x_dmf = (float *) x_dm;

    // Pass 2: per-block scale d.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // Pass 3: int8 sub-block scales, copied four at a time as one int.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}
1142
1143#define VDR_Q6_K_Q8_1_MMQ  8
1144
1145// contiguous u/y values
1146static __dpct_inline__ float
1147vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
1148                           const int8_t *__restrict__ sc, const float &d6,
1149                           const float *__restrict__ d8) {
1150
1151    float sumf_d = 0.0f;
1152
1153#pragma unroll
1154    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
1155        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
1156
1157#pragma unroll
1158        for (int i = i0; i < i0 + 2; ++i) {
1159            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
1160                                    sumi_d.x()); // SIMD dot product
1161            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
1162                                    sumi_d.x()); // SIMD dot product
1163
1164            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
1165                                    sumi_d.y()); // SIMD dot product
1166            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
1167                                    sumi_d.y()); // SIMD dot product
1168        }
1169
1170        sumf_d += d8[i0 / 4] *
1171                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
1172    }
1173
1174    return d6 * sumf_d;
1175}
1176
1177static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
1178    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
1179    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
1180    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
1181    const int &i, const int &j, const int &k) {
1182    (void)x_qh;
1183
1184    const float * x_dmf = (const float *) x_dm;
1185    const float * y_df  = (const float *) y_ds;
1186
1187    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
1188
1189    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
1190    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
1191    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
1192}
1193
// Shared body of the mul_mat_q* kernels. Each work-group computes one
// mmq_y x mmq_x tile of dst, which is written as dst[col*nrows_dst + row]:
//   rows    [get_group(2)*mmq_y, +mmq_y)  of x / dst
//   columns [get_group(1)*mmq_x, +mmq_x)  of y / dst
// Each work-item keeps mmq_y/WARP_SIZE x mmq_x/nwarps accumulators. It owns
// rows local_id(2) + i (i += WARP_SIZE) and columns local_id(1) + j (j += nwarps).
//
// vx holds nrows_x rows of ncols_x/qk blocks of block_q_t. vy holds ncols_y
// columns of nrows_y/QK8_1 block_q8_1. The shared dimension is walked
// WARP_SIZE/qi x-blocks at a time. load_tiles stages x into tile_x_*. Then, for
// each of qr passes, the matching y quants and scales are staged into
// tile_y_qs/tile_y_ds, a barrier is taken, vec_dot accumulates in K-steps of
// vdr, and a second barrier keeps the tiles stable until every work-item has
// finished reading them.
//
// need_sum: when true, the full q8_1 `ds` half2 is staged; otherwise only its
// first component, converted to float, is stored in the same slot.
// The tile_* pointers are scratch shared by the work-group (presumably local
// memory supplied by the launcher - TODO confirm).
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
          vec_dot_q_mul_mat_sycl_t vec_dot>
/*
DPCT1110:8: The total declared local variable size in device function mul_mat_q
exceeds 128 bytes and may cause high register pressure. Consult with your
hardware vendor to find the total register size available and adjust the code,
or use smaller sub-group size to avoid high register pressure.
*/
static __dpct_inline__ void
mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
          float *__restrict__ dst, const int ncols_x, const int nrows_x,
          const int ncols_y, const int nrows_y, const int nrows_dst,
          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
          sycl::half2 *tile_y_ds) {

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    // Number of x blocks consumed per iteration of the outer loop.
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    // First dst row / x row handled by this work-group.
    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
    const int & row_x_0 = row_dst_0;

    // First dst column / y column handled by this work-group.
    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
    const int & col_y_0 = col_dst_0;

    // Per-work-item accumulators: [row step][column step].
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        // i_max = last valid row relative to row_x_0 (used when need_check).
        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
                   blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
            const int kbxd = kqs / QI8_1;

            // Stage one int of y quants per (column, lane).
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                // NOTE(review): dpct::min mixes an unsigned and an int operand here,
                // while the ds loop below uses sycl::min on ints; both clamp to the
                // last valid column.
                const int col_y_eff = dpct::min(
                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
                    ncols_y - 1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
                                    kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(
                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
            }

            // Stage the y block scales: WARP_SIZE/QI8_1 entries per column.
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids =
                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
                    mmq_x;
                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const sycl::half2 *dsi_src =
                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
                       ir * (WARP_SIZE / QI8_1) + kby]
                         .ds;
                sycl::half2 *dsi_dst =
                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    // Reuse the half2 slot to hold the first component as a float.
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = (*dsi_src)[0];
                }
            }

            /*
            DPCT1118:9: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();

            // Accumulate this pass's slice of K into the per-work-item sums.
// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
                            item_ct1.get_local_id(1) + j, k);
                    }
                }
            }

            /*
            DPCT1118:10: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();
        }
    }

    // Write back. j increases monotonically, so once col_dst is out of range
    // every later column is too; returning here happens after the last barrier.
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);

        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
1336
// Tile-shape presets for mul_mat_q4_0:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q4_0 reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q4_0_RDNA2  64
#define  MMQ_Y_Q4_0_RDNA2  128
#define NWARPS_Q4_0_RDNA2  8
#define  MMQ_X_Q4_0_RDNA1  64
#define  MMQ_Y_Q4_0_RDNA1  64
#define NWARPS_Q4_0_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q4_0_AMPERE 4
#define  MMQ_Y_Q4_0_AMPERE 32
#define NWARPS_Q4_0_AMPERE 4
#else
#define  MMQ_X_Q4_0_AMPERE 64
#define  MMQ_Y_Q4_0_AMPERE 128
#define NWARPS_Q4_0_AMPERE 4
#endif
#define  MMQ_X_Q4_0_PASCAL 64
#define  MMQ_Y_Q4_0_PASCAL 64
#define NWARPS_Q4_0_PASCAL 8
1355
1356template <bool need_check> static void
1357    mul_mat_q4_0(
1358    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1359    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1360    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
1361    int *tile_y_qs, sycl::half2 *tile_y_ds) {
1362    int   * tile_x_ql = nullptr;
1363    sycl::half2 *tile_x_dm = nullptr;
1364    int   * tile_x_qh = nullptr;
1365    int   * tile_x_sc = nullptr;
1366
1367//sycl_todo: change according to hardware
1368
1369    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
1370    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
1371    const int nwarps = NWARPS_Q4_0_AMPERE;
1372    allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1373                               tile_x_qs_q4_0, tile_x_d_q4_0);
1374    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
1375              load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
1376              vec_dot_q4_0_q8_1_mul_mat>(
1377        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1378        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1379}
1380
// Tile-shape presets for mul_mat_q4_1:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q4_1 reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q4_1_RDNA2  64
#define  MMQ_Y_Q4_1_RDNA2  128
#define NWARPS_Q4_1_RDNA2  8
#define  MMQ_X_Q4_1_RDNA1  64
#define  MMQ_Y_Q4_1_RDNA1  64
#define NWARPS_Q4_1_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q4_1_AMPERE 4
#define  MMQ_Y_Q4_1_AMPERE 32
#define NWARPS_Q4_1_AMPERE 4
#else
#define  MMQ_X_Q4_1_AMPERE 64
#define  MMQ_Y_Q4_1_AMPERE 128
#define NWARPS_Q4_1_AMPERE 4
#endif
#define  MMQ_X_Q4_1_PASCAL 64
#define  MMQ_Y_Q4_1_PASCAL 64
#define NWARPS_Q4_1_PASCAL 8
1399
1400template <bool need_check> static void
1401    mul_mat_q4_1(
1402    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1403    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1404    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
1405    sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1406    int   * tile_x_ql = nullptr;
1407    sycl::half2 *tile_x_dm = nullptr;
1408    int   * tile_x_qh = nullptr;
1409    int   * tile_x_sc = nullptr;
1410
1411//sycl_todo: change according to hardware
1412    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
1413    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
1414    const int nwarps = NWARPS_Q4_1_AMPERE;
1415    allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1416                               tile_x_qs_q4_1, tile_x_dm_q4_1);
1417    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
1418              load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
1419              vec_dot_q4_1_q8_1_mul_mat>(
1420        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1421        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1422}
1423
// Tile-shape presets for mul_mat_q5_0:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q5_0 reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q5_0_RDNA2  64
#define  MMQ_Y_Q5_0_RDNA2  128
#define NWARPS_Q5_0_RDNA2  8
#define  MMQ_X_Q5_0_RDNA1  64
#define  MMQ_Y_Q5_0_RDNA1  64
#define NWARPS_Q5_0_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q5_0_AMPERE 4
#define  MMQ_Y_Q5_0_AMPERE 32
#define NWARPS_Q5_0_AMPERE 4
#else
#define  MMQ_X_Q5_0_AMPERE 128
#define  MMQ_Y_Q5_0_AMPERE 64
#define NWARPS_Q5_0_AMPERE 4
#endif
#define  MMQ_X_Q5_0_PASCAL 64
#define  MMQ_Y_Q5_0_PASCAL 64
#define NWARPS_Q5_0_PASCAL 8
1442
1443template <bool need_check> static void
1444    mul_mat_q5_0(
1445    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1446    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1447    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
1448    int *tile_y_qs, sycl::half2 *tile_y_ds) {
1449    int   * tile_x_ql = nullptr;
1450    sycl::half2 *tile_x_dm = nullptr;
1451    int   * tile_x_qh = nullptr;
1452    int   * tile_x_sc = nullptr;
1453
1454//sycl_todo: change according to hardware
1455    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
1456    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
1457    const int nwarps = NWARPS_Q5_0_AMPERE;
1458    allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1459                               tile_x_ql_q5_0, tile_x_d_q5_0);
1460    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
1461              load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
1462              vec_dot_q5_0_q8_1_mul_mat>(
1463        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1464        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1465}
1466
// Tile-shape presets for mul_mat_q5_1:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q5_1 reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q5_1_RDNA2  64
#define  MMQ_Y_Q5_1_RDNA2  128
#define NWARPS_Q5_1_RDNA2  8
#define  MMQ_X_Q5_1_RDNA1  64
#define  MMQ_Y_Q5_1_RDNA1  64
#define NWARPS_Q5_1_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q5_1_AMPERE 4
#define  MMQ_Y_Q5_1_AMPERE 32
#define NWARPS_Q5_1_AMPERE 4
#else
#define  MMQ_X_Q5_1_AMPERE 128
#define  MMQ_Y_Q5_1_AMPERE 64
#define NWARPS_Q5_1_AMPERE 4
#endif
#define  MMQ_X_Q5_1_PASCAL 64
#define  MMQ_Y_Q5_1_PASCAL 64
#define NWARPS_Q5_1_PASCAL 8
1485
1486template <bool need_check> static void
1487mul_mat_q5_1(
1488    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1489    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1490    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
1491    sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1492    int   * tile_x_ql = nullptr;
1493    sycl::half2 *tile_x_dm = nullptr;
1494    int   * tile_x_qh = nullptr;
1495    int   * tile_x_sc = nullptr;
1496
1497//sycl_todo: change according to hardware
1498    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
1499    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
1500    const int nwarps = NWARPS_Q5_1_AMPERE;
1501    allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1502                               tile_x_ql_q5_1, tile_x_dm_q5_1);
1503    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
1504              load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
1505              vec_dot_q5_1_q8_1_mul_mat>(
1506        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1507        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1508}
1509
// Tile-shape presets for mul_mat_q8_0:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q8_0 reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q8_0_RDNA2  64
#define  MMQ_Y_Q8_0_RDNA2  128
#define NWARPS_Q8_0_RDNA2  8
#define  MMQ_X_Q8_0_RDNA1  64
#define  MMQ_Y_Q8_0_RDNA1  64
#define NWARPS_Q8_0_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q8_0_AMPERE 4
#define  MMQ_Y_Q8_0_AMPERE 32
#define NWARPS_Q8_0_AMPERE 4
#else
#define  MMQ_X_Q8_0_AMPERE 128
#define  MMQ_Y_Q8_0_AMPERE 64
#define NWARPS_Q8_0_AMPERE 4
#endif
#define  MMQ_X_Q8_0_PASCAL 64
#define  MMQ_Y_Q8_0_PASCAL 64
#define NWARPS_Q8_0_PASCAL 8
1528
1529template <bool need_check> static void
1530    mul_mat_q8_0(
1531    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1532    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1533    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
1534    int *tile_y_qs, sycl::half2 *tile_y_ds) {
1535    int   * tile_x_ql = nullptr;
1536    sycl::half2 *tile_x_dm = nullptr;
1537    int   * tile_x_qh = nullptr;
1538    int   * tile_x_sc = nullptr;
1539
1540//sycl_todo: change according to hardware
1541    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
1542    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
1543    const int nwarps = NWARPS_Q8_0_AMPERE;
1544    allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1545                               tile_x_qs_q8_0, tile_x_d_q8_0);
1546    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
1547              load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
1548              vec_dot_q8_0_q8_1_mul_mat>(
1549        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1550        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1551}
1552
// Tile-shape presets for mul_mat_q2_K:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q2_K reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q2_K_RDNA2  64
#define  MMQ_Y_Q2_K_RDNA2  128
#define NWARPS_Q2_K_RDNA2  8
#define  MMQ_X_Q2_K_RDNA1  128
#define  MMQ_Y_Q2_K_RDNA1  32
#define NWARPS_Q2_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q2_K_AMPERE 4
#define  MMQ_Y_Q2_K_AMPERE 32
#define NWARPS_Q2_K_AMPERE 4
#else
#define  MMQ_X_Q2_K_AMPERE 64
#define  MMQ_Y_Q2_K_AMPERE 128
#define NWARPS_Q2_K_AMPERE 4
#endif
#define  MMQ_X_Q2_K_PASCAL 64
#define  MMQ_Y_Q2_K_PASCAL 64
#define NWARPS_Q2_K_PASCAL 8
1571
1572template <bool need_check> static void
1573mul_mat_q2_K(
1574    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1575    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1576    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
1577    sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
1578    sycl::half2 *tile_y_ds) {
1579    int   * tile_x_ql = nullptr;
1580    sycl::half2 *tile_x_dm = nullptr;
1581    int   * tile_x_qh = nullptr;
1582    int   * tile_x_sc = nullptr;
1583
1584//sycl_todo: change according to hardware
1585    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
1586    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
1587    const int nwarps = NWARPS_Q2_K_AMPERE;
1588    allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1589                               tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
1590    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
1591              load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
1592              vec_dot_q2_K_q8_1_mul_mat>(
1593        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1594        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1595}
1596
// Tile-shape presets for mul_mat_q3_K:
//   MMQ_X  = dst columns per work-group, MMQ_Y = dst rows per work-group,
//   NWARPS = step mul_mat_q uses over local_id(1) (presumably the work-group's
//            dimension-1 size - TODO confirm).
// mul_mat_q3_K reads only the *_AMPERE preset. With SYCL_USE_XMX defined, that
// preset shrinks to mmq_x = 4, mmq_y = 32.
#define  MMQ_X_Q3_K_RDNA2  128
#define  MMQ_Y_Q3_K_RDNA2  64
#define NWARPS_Q3_K_RDNA2  8
#define  MMQ_X_Q3_K_RDNA1  32
#define  MMQ_Y_Q3_K_RDNA1  128
#define NWARPS_Q3_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q3_K_AMPERE 4
#define  MMQ_Y_Q3_K_AMPERE 32
#define NWARPS_Q3_K_AMPERE 4
#else
#define  MMQ_X_Q3_K_AMPERE 128
#define  MMQ_Y_Q3_K_AMPERE 128
#define NWARPS_Q3_K_AMPERE 4
#endif
#define  MMQ_X_Q3_K_PASCAL 64
#define  MMQ_Y_Q3_K_PASCAL 64
#define NWARPS_Q3_K_PASCAL 8
1615
1616template <bool need_check> static void
1617mul_mat_q3_K(
1618    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1619    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1620    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
1621    sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
1622    int *tile_y_qs, sycl::half2 *tile_y_ds) {
1623    int   * tile_x_ql = nullptr;
1624    sycl::half2 *tile_x_dm = nullptr;
1625    int   * tile_x_qh = nullptr;
1626    int   * tile_x_sc = nullptr;
1627
1628//sycl_todo: change according to hardware
1629    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
1630    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
1631    const int nwarps = NWARPS_Q3_K_AMPERE;
1632    allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1633                               tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
1634                               tile_x_sc_q3_K);
1635    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
1636              load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
1637              vec_dot_q3_K_q8_1_mul_mat>(
1638        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1639        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1640}
1641
// Q4_K mul_mat tile shapes. MMQ_X_* / MMQ_Y_* are the tile extents used as the
// mmq_x / mmq_y template arguments of mul_mat_q, and NWARPS_* is the nwarps
// template argument (launchers use it as the work-group's dimension 1).
// The RDNA2 / RDNA1 / AMPERE / PASCAL suffixes are tier labels inherited from
// the CUDA/HIP port; the ggml_mul_mat_*_q8_1_sycl launchers select these tiers
// for compute capability >= VER_GEN13 / VER_GEN12 / VER_GEN9 / VER_4VEC
// respectively (confirm for the Q4_K launcher).
// NOTE(review): mul_mat_q4_K always uses the *_AMPERE triple, so a launcher that
// sizes local memory / work-groups from another tier must agree with it.
#define  MMQ_X_Q4_K_RDNA2  64
#define  MMQ_Y_Q4_K_RDNA2  128
#define NWARPS_Q4_K_RDNA2  8
#define  MMQ_X_Q4_K_RDNA1  32
#define  MMQ_Y_Q4_K_RDNA1  64
#define NWARPS_Q4_K_RDNA1  8
#if defined(SYCL_USE_XMX)
// Much smaller AMPERE-tier tile when built with SYCL_USE_XMX.
#define  MMQ_X_Q4_K_AMPERE 4
#define  MMQ_Y_Q4_K_AMPERE 32
#define NWARPS_Q4_K_AMPERE 4
#else
#define  MMQ_X_Q4_K_AMPERE 64
#define  MMQ_Y_Q4_K_AMPERE 128
#define NWARPS_Q4_K_AMPERE 4
#endif
#define  MMQ_X_Q4_K_PASCAL 64
#define  MMQ_Y_Q4_K_PASCAL 64
#define NWARPS_Q4_K_PASCAL 8
1660
1661template <bool need_check> static void
1662    mul_mat_q4_K(
1663    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1664    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1665    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
1666    sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
1667    sycl::half2 *tile_y_ds) {
1668    int   * tile_x_ql = nullptr;
1669    sycl::half2 *tile_x_dm = nullptr;
1670    int   * tile_x_qh = nullptr;
1671    int   * tile_x_sc = nullptr;
1672
1673//sycl_todo: change according to hardware
1674    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
1675    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
1676    const int nwarps = NWARPS_Q4_K_AMPERE;
1677    allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1678                               tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
1679    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
1680              load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
1681              vec_dot_q4_K_q8_1_mul_mat>(
1682        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1683        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1684}
1685
// Q5_K mul_mat tile shapes. MMQ_X_* / MMQ_Y_* are the tile extents used as the
// mmq_x / mmq_y template arguments of mul_mat_q, and NWARPS_* is the nwarps
// template argument (launchers use it as the work-group's dimension 1).
// The RDNA2 / RDNA1 / AMPERE / PASCAL suffixes are tier labels inherited from
// the CUDA/HIP port; the ggml_mul_mat_*_q8_1_sycl launchers select these tiers
// for compute capability >= VER_GEN13 / VER_GEN12 / VER_GEN9 / VER_4VEC
// respectively (confirm for the Q5_K launcher).
// NOTE(review): mul_mat_q5_K always uses the *_AMPERE triple, so a launcher that
// sizes local memory / work-groups from another tier must agree with it.
#define  MMQ_X_Q5_K_RDNA2  64
#define  MMQ_Y_Q5_K_RDNA2  128
#define NWARPS_Q5_K_RDNA2  8
#define  MMQ_X_Q5_K_RDNA1  32
#define  MMQ_Y_Q5_K_RDNA1  64
#define NWARPS_Q5_K_RDNA1  8
#if defined(SYCL_USE_XMX)
// Much smaller AMPERE-tier tile when built with SYCL_USE_XMX.
#define  MMQ_X_Q5_K_AMPERE 4
#define  MMQ_Y_Q5_K_AMPERE 32
#define NWARPS_Q5_K_AMPERE 4
#else
#define  MMQ_X_Q5_K_AMPERE 64
#define  MMQ_Y_Q5_K_AMPERE 128
#define NWARPS_Q5_K_AMPERE 4
#endif
#define  MMQ_X_Q5_K_PASCAL 64
#define  MMQ_Y_Q5_K_PASCAL 64
#define NWARPS_Q5_K_PASCAL 8
1704
1705template <bool need_check> static void
1706mul_mat_q5_K(
1707    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1708    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1709    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
1710    sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
1711    sycl::half2 *tile_y_ds) {
1712    int   * tile_x_ql = nullptr;
1713    sycl::half2 *tile_x_dm = nullptr;
1714    int   * tile_x_qh = nullptr;
1715    int   * tile_x_sc = nullptr;
1716
1717//sycl_todo: change according to hardware
1718    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
1719    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
1720    const int nwarps = NWARPS_Q5_K_AMPERE;
1721    allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1722                               tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
1723    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
1724              load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
1725              vec_dot_q5_K_q8_1_mul_mat>(
1726        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1727        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1728}
1729
// Q6_K mul_mat tile shapes. MMQ_X_* / MMQ_Y_* are the tile extents used as the
// mmq_x / mmq_y template arguments of mul_mat_q, and NWARPS_* is the nwarps
// template argument (launchers use it as the work-group's dimension 1).
// The RDNA2 / RDNA1 / AMPERE / PASCAL suffixes are tier labels inherited from
// the CUDA/HIP port; the ggml_mul_mat_*_q8_1_sycl launchers select these tiers
// for compute capability >= VER_GEN13 / VER_GEN12 / VER_GEN9 / VER_4VEC
// respectively (confirm for the Q6_K launcher).
// NOTE(review): mul_mat_q6_K always uses the *_AMPERE triple, so a launcher that
// sizes local memory / work-groups from another tier must agree with it.
#define  MMQ_X_Q6_K_RDNA2  64
#define  MMQ_Y_Q6_K_RDNA2  128
#define NWARPS_Q6_K_RDNA2  8
#define  MMQ_X_Q6_K_RDNA1  32
#define  MMQ_Y_Q6_K_RDNA1  64
#define NWARPS_Q6_K_RDNA1  8
#if defined(SYCL_USE_XMX)
// Much smaller AMPERE-tier tile when built with SYCL_USE_XMX.
#define  MMQ_X_Q6_K_AMPERE 4
#define  MMQ_Y_Q6_K_AMPERE 32
#define NWARPS_Q6_K_AMPERE 4
#else
#define  MMQ_X_Q6_K_AMPERE 64
#define  MMQ_Y_Q6_K_AMPERE 64
#define NWARPS_Q6_K_AMPERE 4
#endif
#define  MMQ_X_Q6_K_PASCAL 64
#define  MMQ_Y_Q6_K_PASCAL 64
#define NWARPS_Q6_K_PASCAL 8
1748
1749template <bool need_check> static void
1750    mul_mat_q6_K(
1751    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1752    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1753    const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
1754    int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1755    // int   * tile_x_ql = nullptr;
1756    // sycl::half2 *tile_x_dm = nullptr;
1757    int   * tile_x_qh = nullptr;
1758    // int   * tile_x_sc = nullptr;
1759
1760//sycl_todo: change according to hardware
1761    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
1762    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
1763    const int nwarps = NWARPS_Q6_K_AMPERE;
1764    allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1765                               tile_x_ql, tile_x_dm, tile_x_sc);
1766    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
1767              load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
1768              vec_dot_q6_K_q8_1_mul_mat>(
1769        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1770        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1771}
1772
1773static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
1774                                        float *dst, const int ncols_x,
1775                                        const int nrows_x, const int ncols_y,
1776                                        const int nrows_y, const int nrows_dst,
1777                                        dpct::queue_ptr stream) try {
1778
1779    int id;
1780    SYCL_CHECK(
1781        CHECK_TRY_ERROR(id = get_current_device_id()));
1782    const int compute_capability = ggml_sycl_info().devices[id].cc;
1783
1784    int mmq_x, mmq_y, nwarps;
1785    if (compute_capability >= VER_GEN13) {
1786        mmq_x  =  MMQ_X_Q4_0_RDNA2;
1787        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
1788        nwarps = NWARPS_Q4_0_RDNA2;
1789    } else if (compute_capability >= VER_GEN12) {
1790        mmq_x  =  MMQ_X_Q4_0_RDNA1;
1791        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
1792        nwarps = NWARPS_Q4_0_RDNA1;
1793    } else if (compute_capability >= VER_GEN9) {
1794        mmq_x  =  MMQ_X_Q4_0_AMPERE;
1795        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
1796        nwarps = NWARPS_Q4_0_AMPERE;
1797    } else if (compute_capability >= VER_4VEC) {
1798        mmq_x  =  MMQ_X_Q4_0_PASCAL;
1799        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
1800        nwarps = NWARPS_Q4_0_PASCAL;
1801    } else {
1802        GGML_ABORT("fatal error");
1803    }
1804
1805    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
1806    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
1807    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
1808    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
1809
1810    if (nrows_x % mmq_y == 0) {
1811        const bool need_check = false;
1812        /*
1813        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
1814        the limit. To get the device limit, query
1815        info::device::max_work_group_size. Adjust the work-group size if needed.
1816        */
1817        {
1818            dpct::has_capability_or_fail(stream->get_device(),
1819                                         {sycl::aspect::fp16});
1820
1821            stream->submit([&](sycl::handler &cgh) {
1822                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
1823                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
1824                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
1825                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
1826                    cgh);
1827                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1828                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1829                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1830                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1831
1832                cgh.parallel_for(
1833                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
1834                    [=](sycl::nd_item<3> item_ct1) {
1835                        mul_mat_q4_0<need_check>(
1836                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1837                            nrows_dst, item_ct1,
1838                            get_pointer(tile_x_qs_q4_0_acc_ct1),
1839                            get_pointer(tile_x_d_q4_0_acc_ct1),
1840                            get_pointer(tile_y_qs_acc_ct1),
1841                            get_pointer(tile_y_ds_acc_ct1));
1842                    });
1843            });
1844        }
1845    } else {
1846        const bool need_check = true;
1847        /*
1848        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
1849        the limit. To get the device limit, query
1850        info::device::max_work_group_size. Adjust the work-group size if needed.
1851        */
1852        {
1853            dpct::has_capability_or_fail(stream->get_device(),
1854                                         {sycl::aspect::fp16});
1855
1856            stream->submit([&](sycl::handler &cgh) {
1857                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
1858                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
1859                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
1860                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
1861                    cgh);
1862                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1863                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1864                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1865                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1866
1867                cgh.parallel_for(
1868                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
1869                    [=](sycl::nd_item<3> item_ct1) {
1870                        mul_mat_q4_0<need_check>(
1871                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1872                            nrows_dst, item_ct1,
1873                            get_pointer(tile_x_qs_q4_0_acc_ct1),
1874                            get_pointer(tile_x_d_q4_0_acc_ct1),
1875                            get_pointer(tile_y_qs_acc_ct1),
1876                            get_pointer(tile_y_ds_acc_ct1));
1877                    });
1878            });
1879        }
1880    }
1881}
1882catch (sycl::exception const &exc) {
1883  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1884            << ", line:" << __LINE__ << std::endl;
1885  std::exit(1);
1886}
1887
1888static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
1889                                        float *dst, const int ncols_x,
1890                                        const int nrows_x, const int ncols_y,
1891                                        const int nrows_y, const int nrows_dst,
1892                                        dpct::queue_ptr stream) try {
1893
1894    int id;
1895    SYCL_CHECK(
1896        CHECK_TRY_ERROR(id = get_current_device_id()));
1897    const int compute_capability = ggml_sycl_info().devices[id].cc;
1898
1899    int mmq_x, mmq_y, nwarps;
1900    if (compute_capability >= VER_GEN13) {
1901        mmq_x  =  MMQ_X_Q4_1_RDNA2;
1902        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
1903        nwarps = NWARPS_Q4_1_RDNA2;
1904    } else if (compute_capability >= VER_GEN12) {
1905        mmq_x  =  MMQ_X_Q4_1_RDNA1;
1906        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
1907        nwarps = NWARPS_Q4_1_RDNA1;
1908    } else if (compute_capability >= VER_GEN9) {
1909        mmq_x  =  MMQ_X_Q4_1_AMPERE;
1910        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
1911        nwarps = NWARPS_Q4_1_AMPERE;
1912    } else if (compute_capability >= VER_4VEC) {
1913        mmq_x  =  MMQ_X_Q4_1_PASCAL;
1914        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
1915        nwarps = NWARPS_Q4_1_PASCAL;
1916    } else {
1917        GGML_ABORT("fatal error");
1918    }
1919
1920    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
1921    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
1922    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
1923    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
1924
1925    if (nrows_x % mmq_y == 0) {
1926        const bool need_check = false;
1927        /*
1928        DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
1929        the limit. To get the device limit, query
1930        info::device::max_work_group_size. Adjust the work-group size if needed.
1931        */
1932        {
1933            dpct::has_capability_or_fail(stream->get_device(),
1934                                         {sycl::aspect::fp16});
1935
1936            stream->submit([&](sycl::handler &cgh) {
1937                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1938                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1939                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1940                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1941                    cgh);
1942                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1943                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1944                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1945                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1946
1947                cgh.parallel_for(
1948                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
1949                    [=](sycl::nd_item<3> item_ct1) {
1950                        mul_mat_q4_1<need_check>(
1951                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1952                            nrows_dst, item_ct1,
1953                            get_pointer(tile_x_qs_q4_1_acc_ct1),
1954                            get_pointer(tile_x_dm_q4_1_acc_ct1),
1955                            get_pointer(tile_y_qs_acc_ct1),
1956                            get_pointer(tile_y_ds_acc_ct1));
1957                    });
1958            });
1959        }
1960    } else {
1961        const bool need_check = true;
1962        /*
1963        DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
1964        the limit. To get the device limit, query
1965        info::device::max_work_group_size. Adjust the work-group size if needed.
1966        */
1967        {
1968            dpct::has_capability_or_fail(stream->get_device(),
1969                                         {sycl::aspect::fp16});
1970
1971            stream->submit([&](sycl::handler &cgh) {
1972                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1973                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1974                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1975                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1976                    cgh);
1977                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1978                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1979                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1980                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1981
1982                cgh.parallel_for(
1983                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
1984                    [=](sycl::nd_item<3> item_ct1) {
1985                        mul_mat_q4_1<need_check>(
1986                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1987                            nrows_dst, item_ct1,
1988                            get_pointer(tile_x_qs_q4_1_acc_ct1),
1989                            get_pointer(tile_x_dm_q4_1_acc_ct1),
1990                            get_pointer(tile_y_qs_acc_ct1),
1991                            get_pointer(tile_y_ds_acc_ct1));
1992                    });
1993            });
1994        }
1995    }
1996}
1997catch (sycl::exception const &exc) {
1998  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1999            << ", line:" << __LINE__ << std::endl;
2000  std::exit(1);
2001}
2002
2003static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
2004                                        float *dst, const int ncols_x,
2005                                        const int nrows_x, const int ncols_y,
2006                                        const int nrows_y, const int nrows_dst,
2007                                        dpct::queue_ptr stream) try {
2008
2009    int id;
2010    SYCL_CHECK(
2011        CHECK_TRY_ERROR(id = get_current_device_id()));
2012    const int compute_capability = ggml_sycl_info().devices[id].cc;
2013
2014    int mmq_x, mmq_y, nwarps;
2015    if (compute_capability >= VER_GEN13) {
2016        mmq_x  =  MMQ_X_Q5_0_RDNA2;
2017        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
2018        nwarps = NWARPS_Q5_0_RDNA2;
2019    } else if (compute_capability >= VER_GEN12) {
2020        mmq_x  =  MMQ_X_Q5_0_RDNA1;
2021        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
2022        nwarps = NWARPS_Q5_0_RDNA1;
2023    } else if (compute_capability >= VER_GEN9) {
2024        mmq_x  =  MMQ_X_Q5_0_AMPERE;
2025        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
2026        nwarps = NWARPS_Q5_0_AMPERE;
2027    } else if (compute_capability >= VER_4VEC) {
2028        mmq_x  =  MMQ_X_Q5_0_PASCAL;
2029        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
2030        nwarps = NWARPS_Q5_0_PASCAL;
2031    } else {
2032        GGML_ABORT("fatal error");
2033    }
2034
2035    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2036    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2037    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2038    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2039
2040    if (nrows_x % mmq_y == 0) {
2041        const bool need_check = false;
2042        /*
2043        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
2044        the limit. To get the device limit, query
2045        info::device::max_work_group_size. Adjust the work-group size if needed.
2046        */
2047        {
2048            dpct::has_capability_or_fail(stream->get_device(),
2049                                         {sycl::aspect::fp16});
2050
2051            stream->submit([&](sycl::handler &cgh) {
2052                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
2053                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2054                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
2055                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
2056                    cgh);
2057                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2058                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2059                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2060                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2061
2062                cgh.parallel_for(
2063                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2064                    [=](sycl::nd_item<3> item_ct1) {
2065                        mul_mat_q5_0<need_check>(
2066                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2067                            nrows_dst, item_ct1,
2068                            get_pointer(tile_x_ql_q5_0_acc_ct1),
2069                            get_pointer(tile_x_d_q5_0_acc_ct1),
2070                            get_pointer(tile_y_qs_acc_ct1),
2071                            get_pointer(tile_y_ds_acc_ct1));
2072                    });
2073            });
2074        }
2075    } else {
2076        const bool need_check = true;
2077        /*
2078        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
2079        the limit. To get the device limit, query
2080        info::device::max_work_group_size. Adjust the work-group size if needed.
2081        */
2082        {
2083            dpct::has_capability_or_fail(stream->get_device(),
2084                                         {sycl::aspect::fp16});
2085
2086            stream->submit([&](sycl::handler &cgh) {
2087                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
2088                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2089                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
2090                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
2091                    cgh);
2092                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2093                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2094                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2095                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2096
2097                cgh.parallel_for(
2098                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2099                    [=](sycl::nd_item<3> item_ct1) {
2100                        mul_mat_q5_0<need_check>(
2101                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2102                            nrows_dst, item_ct1,
2103                            get_pointer(tile_x_ql_q5_0_acc_ct1),
2104                            get_pointer(tile_x_d_q5_0_acc_ct1),
2105                            get_pointer(tile_y_qs_acc_ct1),
2106                            get_pointer(tile_y_ds_acc_ct1));
2107                    });
2108            });
2109        }
2110    }
2111}
2112catch (sycl::exception const &exc) {
2113  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2114            << ", line:" << __LINE__ << std::endl;
2115  std::exit(1);
2116}
2117
2118static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
2119                                        float *dst, const int ncols_x,
2120                                        const int nrows_x, const int ncols_y,
2121                                        const int nrows_y, const int nrows_dst,
2122                                        dpct::queue_ptr stream) try {
2123
2124    int id;
2125    SYCL_CHECK(
2126        CHECK_TRY_ERROR(id = get_current_device_id()));
2127    const int compute_capability = ggml_sycl_info().devices[id].cc;
2128
2129    int mmq_x, mmq_y, nwarps;
2130    if (compute_capability >= VER_GEN13) {
2131        mmq_x  =  MMQ_X_Q5_1_RDNA2;
2132        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
2133        nwarps = NWARPS_Q5_1_RDNA2;
2134    } else if (compute_capability >= VER_GEN12) {
2135        mmq_x  =  MMQ_X_Q5_1_RDNA1;
2136        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
2137        nwarps = NWARPS_Q5_1_RDNA1;
2138    } else if (compute_capability >= VER_GEN9) {
2139        mmq_x  =  MMQ_X_Q5_1_AMPERE;
2140        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
2141        nwarps = NWARPS_Q5_1_AMPERE;
2142    } else if (compute_capability >= VER_4VEC) {
2143        mmq_x  =  MMQ_X_Q5_1_PASCAL;
2144        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
2145        nwarps = NWARPS_Q5_1_PASCAL;
2146    } else {
2147        GGML_ABORT("fatal error");
2148    }
2149
2150    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2151    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2152    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2153    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2154
2155    if (nrows_x % mmq_y == 0) {
2156        const bool need_check = false;
2157        /*
2158        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
2159        the limit. To get the device limit, query
2160        info::device::max_work_group_size. Adjust the work-group size if needed.
2161        */
2162        {
2163            dpct::has_capability_or_fail(stream->get_device(),
2164                                         {sycl::aspect::fp16});
2165
2166            stream->submit([&](sycl::handler &cgh) {
2167                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
2168                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2169                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
2170                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
2171                    cgh);
2172                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2173                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2174                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2175                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2176
2177                cgh.parallel_for(
2178                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2179                    [=](sycl::nd_item<3> item_ct1) {
2180                        mul_mat_q5_1<need_check>(
2181                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2182                            nrows_dst, item_ct1,
2183                            get_pointer(tile_x_ql_q5_1_acc_ct1),
2184                            get_pointer(tile_x_dm_q5_1_acc_ct1),
2185                            get_pointer(tile_y_qs_acc_ct1),
2186                            get_pointer(tile_y_ds_acc_ct1));
2187                    });
2188            });
2189        }
2190    } else {
2191        const bool need_check = true;
2192        /*
2193        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
2194        the limit. To get the device limit, query
2195        info::device::max_work_group_size. Adjust the work-group size if needed.
2196        */
2197        {
2198            dpct::has_capability_or_fail(stream->get_device(),
2199                                         {sycl::aspect::fp16});
2200
2201            stream->submit([&](sycl::handler &cgh) {
2202                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
2203                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2204                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
2205                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
2206                    cgh);
2207                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2208                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2209                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2210                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2211
2212                cgh.parallel_for(
2213                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2214                    [=](sycl::nd_item<3> item_ct1) {
2215                        mul_mat_q5_1<need_check>(
2216                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2217                            nrows_dst, item_ct1,
2218                            get_pointer(tile_x_ql_q5_1_acc_ct1),
2219                            get_pointer(tile_x_dm_q5_1_acc_ct1),
2220                            get_pointer(tile_y_qs_acc_ct1),
2221                            get_pointer(tile_y_ds_acc_ct1));
2222                    });
2223            });
2224        }
2225    }
2226}
2227catch (sycl::exception const &exc) {
2228  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2229            << ", line:" << __LINE__ << std::endl;
2230  std::exit(1);
2231}
2232
2233static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
2234                                        float *dst, const int ncols_x,
2235                                        const int nrows_x, const int ncols_y,
2236                                        const int nrows_y, const int nrows_dst,
2237                                        dpct::queue_ptr stream) try {
2238
2239    int id;
2240    SYCL_CHECK(
2241        CHECK_TRY_ERROR(id = get_current_device_id()));
2242    const int compute_capability = ggml_sycl_info().devices[id].cc;
2243
2244    int mmq_x, mmq_y, nwarps;
2245    if (compute_capability >= VER_GEN13) {
2246        mmq_x  =  MMQ_X_Q8_0_RDNA2;
2247        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
2248        nwarps = NWARPS_Q8_0_RDNA2;
2249    } else if (compute_capability >= VER_GEN12) {
2250        mmq_x  =  MMQ_X_Q8_0_RDNA1;
2251        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
2252        nwarps = NWARPS_Q8_0_RDNA1;
2253    } else if (compute_capability >= VER_GEN9) {
2254        mmq_x  =  MMQ_X_Q8_0_AMPERE;
2255        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
2256        nwarps = NWARPS_Q8_0_AMPERE;
2257    } else if (compute_capability >= VER_4VEC) {
2258        mmq_x  =  MMQ_X_Q8_0_PASCAL;
2259        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
2260        nwarps = NWARPS_Q8_0_PASCAL;
2261    } else {
2262        GGML_ABORT("fatal error");
2263    }
2264
2265    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2266    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2267    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2268    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2269
2270    if (nrows_x % mmq_y == 0) {
2271        const bool need_check = false;
2272        /*
2273        DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
2274        the limit. To get the device limit, query
2275        info::device::max_work_group_size. Adjust the work-group size if needed.
2276        */
2277        {
2278            dpct::has_capability_or_fail(stream->get_device(),
2279                                         {sycl::aspect::fp16});
2280
2281            stream->submit([&](sycl::handler &cgh) {
2282                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2283                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2284                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2285                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2286                    cgh);
2287                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2288                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2289                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2290                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2291
2292                cgh.parallel_for(
2293                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2294                    [=](sycl::nd_item<3> item_ct1) {
2295                        mul_mat_q8_0<need_check>(
2296                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2297                            nrows_dst, item_ct1,
2298                            get_pointer(tile_x_qs_q8_0_acc_ct1),
2299                            get_pointer(tile_x_d_q8_0_acc_ct1),
2300                            get_pointer(tile_y_qs_acc_ct1),
2301                            get_pointer(tile_y_ds_acc_ct1));
2302                    });
2303            });
2304        }
2305    } else {
2306        const bool need_check = true;
2307        /*
2308        DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
2309        the limit. To get the device limit, query
2310        info::device::max_work_group_size. Adjust the work-group size if needed.
2311        */
2312        {
2313            dpct::has_capability_or_fail(stream->get_device(),
2314                                         {sycl::aspect::fp16});
2315
2316            stream->submit([&](sycl::handler &cgh) {
2317                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2318                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2319                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2320                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2321                    cgh);
2322                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2323                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2324                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2325                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2326
2327                cgh.parallel_for(
2328                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2329                    [=](sycl::nd_item<3> item_ct1) {
2330                        mul_mat_q8_0<need_check>(
2331                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2332                            nrows_dst, item_ct1,
2333                            get_pointer(tile_x_qs_q8_0_acc_ct1),
2334                            get_pointer(tile_x_d_q8_0_acc_ct1),
2335                            get_pointer(tile_y_qs_acc_ct1),
2336                            get_pointer(tile_y_ds_acc_ct1));
2337                    });
2338            });
2339        }
2340    }
2341}
2342catch (sycl::exception const &exc) {
2343  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2344            << ", line:" << __LINE__ << std::endl;
2345  std::exit(1);
2346}
2347
2348static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
2349                                        float *dst, const int ncols_x,
2350                                        const int nrows_x, const int ncols_y,
2351                                        const int nrows_y, const int nrows_dst,
2352                                        dpct::queue_ptr stream) try {
2353
2354    int id;
2355    SYCL_CHECK(
2356        CHECK_TRY_ERROR(id = get_current_device_id()));
2357    const int compute_capability = ggml_sycl_info().devices[id].cc;
2358
2359    int mmq_x, mmq_y, nwarps;
2360    if (compute_capability >= VER_GEN13) {
2361        mmq_x  =  MMQ_X_Q2_K_RDNA2;
2362        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
2363        nwarps = NWARPS_Q2_K_RDNA2;
2364    } else if (compute_capability >= VER_GEN12) {
2365        mmq_x  =  MMQ_X_Q2_K_RDNA1;
2366        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
2367        nwarps = NWARPS_Q2_K_RDNA1;
2368    } else if (compute_capability >= VER_GEN9) {
2369        mmq_x  =  MMQ_X_Q2_K_AMPERE;
2370        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
2371        nwarps = NWARPS_Q2_K_AMPERE;
2372    } else if (compute_capability >= VER_4VEC) {
2373        mmq_x  =  MMQ_X_Q2_K_PASCAL;
2374        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
2375        nwarps = NWARPS_Q2_K_PASCAL;
2376    } else {
2377        GGML_ABORT("fatal error");
2378    }
2379
2380    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2381    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2382    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2383    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2384
2385    if (nrows_x % mmq_y == 0) {
2386        const bool need_check = false;
2387        /*
2388        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
2389        the limit. To get the device limit, query
2390        info::device::max_work_group_size. Adjust the work-group size if needed.
2391        */
2392        {
2393            dpct::has_capability_or_fail(stream->get_device(),
2394                                         {sycl::aspect::fp16});
2395
2396            stream->submit([&](sycl::handler &cgh) {
2397                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
2398                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2399                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
2400                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
2401                    cgh);
2402                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
2403                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2404                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2405                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2406                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2407                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2408
2409                cgh.parallel_for(
2410                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2411                    [=](sycl::nd_item<3> item_ct1) {
2412                        mul_mat_q2_K<need_check>(
2413                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2414                            nrows_dst, item_ct1,
2415                            get_pointer(tile_x_ql_q2_K_acc_ct1),
2416                            get_pointer(tile_x_dm_q2_K_acc_ct1),
2417                            get_pointer(tile_x_sc_q2_K_acc_ct1),
2418                            get_pointer(tile_y_qs_acc_ct1),
2419                            get_pointer(tile_y_ds_acc_ct1));
2420                    });
2421            });
2422        }
2423    } else {
2424        const bool need_check = true;
2425        /*
2426        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
2427        the limit. To get the device limit, query
2428        info::device::max_work_group_size. Adjust the work-group size if needed.
2429        */
2430        {
2431            dpct::has_capability_or_fail(stream->get_device(),
2432                                         {sycl::aspect::fp16});
2433
2434            stream->submit([&](sycl::handler &cgh) {
2435                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
2436                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2437                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
2438                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
2439                    cgh);
2440                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
2441                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2442                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2443                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2444                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2445                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2446
2447                cgh.parallel_for(
2448                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2449                    [=](sycl::nd_item<3> item_ct1) {
2450                        mul_mat_q2_K<need_check>(
2451                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2452                            nrows_dst, item_ct1,
2453                            get_pointer(tile_x_ql_q2_K_acc_ct1),
2454                            get_pointer(tile_x_dm_q2_K_acc_ct1),
2455                            get_pointer(tile_x_sc_q2_K_acc_ct1),
2456                            get_pointer(tile_y_qs_acc_ct1),
2457                            get_pointer(tile_y_ds_acc_ct1));
2458                    });
2459            });
2460        }
2461    }
2462}
2463catch (sycl::exception const &exc) {
2464  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2465            << ", line:" << __LINE__ << std::endl;
2466  std::exit(1);
2467}
2468
2469static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
2470                                        float *dst, const int ncols_x,
2471                                        const int nrows_x, const int ncols_y,
2472                                        const int nrows_y, const int nrows_dst,
2473                                        dpct::queue_ptr stream) try {
2474
2475#if QK_K == 256
2476
2477    int id;
2478    SYCL_CHECK(
2479        CHECK_TRY_ERROR(id = get_current_device_id()));
2480    const int compute_capability = ggml_sycl_info().devices[id].cc;
2481
2482    int mmq_x, mmq_y, nwarps;
2483    if (compute_capability >= VER_GEN13) {
2484        mmq_x  =  MMQ_X_Q3_K_RDNA2;
2485        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
2486        nwarps = NWARPS_Q3_K_RDNA2;
2487    } else if (compute_capability >= VER_GEN12) {
2488        mmq_x  =  MMQ_X_Q3_K_RDNA1;
2489        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
2490        nwarps = NWARPS_Q3_K_RDNA1;
2491    } else if (compute_capability >= VER_GEN9) {
2492        mmq_x  =  MMQ_X_Q3_K_AMPERE;
2493        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
2494        nwarps = NWARPS_Q3_K_AMPERE;
2495    } else if (compute_capability >= VER_4VEC) {
2496        mmq_x  =  MMQ_X_Q3_K_PASCAL;
2497        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
2498        nwarps = NWARPS_Q3_K_PASCAL;
2499    } else {
2500        GGML_ABORT("fatal error");
2501    }
2502
2503    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2504    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2505    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2506    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2507
2508    if (nrows_x % mmq_y == 0) {
2509        const bool need_check = false;
2510        /*
2511        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
2512        the limit. To get the device limit, query
2513        info::device::max_work_group_size. Adjust the work-group size if needed.
2514        */
2515        {
2516            dpct::has_capability_or_fail(stream->get_device(),
2517                                         {sycl::aspect::fp16});
2518
2519            stream->submit([&](sycl::handler &cgh) {
2520                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
2521                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2522                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
2523                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
2524                    cgh);
2525                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
2526                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
2527                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
2528                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2529                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2530                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2531                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2532                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2533
2534                cgh.parallel_for(
2535                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2536                    [=](sycl::nd_item<3> item_ct1) {
2537                        mul_mat_q3_K<need_check>(
2538                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2539                            nrows_dst, item_ct1,
2540                            get_pointer(tile_x_ql_q3_K_acc_ct1),
2541                            get_pointer(tile_x_dm_q3_K_acc_ct1),
2542                            get_pointer(tile_x_qh_q3_K_acc_ct1),
2543                            get_pointer(tile_x_sc_q3_K_acc_ct1),
2544                            get_pointer(tile_y_qs_acc_ct1),
2545                            get_pointer(tile_y_ds_acc_ct1));
2546                    });
2547            });
2548        }
2549    } else {
2550        const bool need_check = true;
2551        /*
2552        DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
2553        the limit. To get the device limit, query
2554        info::device::max_work_group_size. Adjust the work-group size if needed.
2555        */
2556        {
2557            dpct::has_capability_or_fail(stream->get_device(),
2558                                         {sycl::aspect::fp16});
2559
2560            stream->submit([&](sycl::handler &cgh) {
2561                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
2562                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2563                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
2564                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
2565                    cgh);
2566                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
2567                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
2568                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
2569                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2570                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2571                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2572                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2573                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2574
2575                cgh.parallel_for(
2576                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2577                    [=](sycl::nd_item<3> item_ct1) {
2578                        mul_mat_q3_K<need_check>(
2579                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2580                            nrows_dst, item_ct1,
2581                            get_pointer(tile_x_ql_q3_K_acc_ct1),
2582                            get_pointer(tile_x_dm_q3_K_acc_ct1),
2583                            get_pointer(tile_x_qh_q3_K_acc_ct1),
2584                            get_pointer(tile_x_sc_q3_K_acc_ct1),
2585                            get_pointer(tile_y_qs_acc_ct1),
2586                            get_pointer(tile_y_ds_acc_ct1));
2587                    });
2588            });
2589        }
2590    }
2591#endif
2592}
2593catch (sycl::exception const &exc) {
2594  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2595            << ", line:" << __LINE__ << std::endl;
2596  std::exit(1);
2597}
2598
2599static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
2600                                        float *dst, const int ncols_x,
2601                                        const int nrows_x, const int ncols_y,
2602                                        const int nrows_y, const int nrows_dst,
2603                                        dpct::queue_ptr stream) try {
2604
2605    int id;
2606    SYCL_CHECK(
2607        CHECK_TRY_ERROR(id = get_current_device_id()));
2608    const int compute_capability = ggml_sycl_info().devices[id].cc;
2609
2610    int mmq_x, mmq_y, nwarps;
2611    if (compute_capability >= VER_GEN13) {
2612        mmq_x  =  MMQ_X_Q4_K_RDNA2;
2613        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
2614        nwarps = NWARPS_Q4_K_RDNA2;
2615    } else if (compute_capability >= VER_GEN12) {
2616        mmq_x  =  MMQ_X_Q4_K_RDNA1;
2617        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
2618        nwarps = NWARPS_Q4_K_RDNA1;
2619    } else if (compute_capability >= VER_GEN9) {
2620        mmq_x  =  MMQ_X_Q4_K_AMPERE;
2621        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
2622        nwarps = NWARPS_Q4_K_AMPERE;
2623    } else if (compute_capability >= VER_4VEC) {
2624        mmq_x  =  MMQ_X_Q4_K_PASCAL;
2625        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
2626        nwarps = NWARPS_Q4_K_PASCAL;
2627    } else {
2628        GGML_ABORT("fatal error");
2629    }
2630
2631    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2632    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2633    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2634    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2635
2636    if (nrows_x % mmq_y == 0) {
2637        const bool need_check = false;
2638        /*
2639        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
2640        the limit. To get the device limit, query
2641        info::device::max_work_group_size. Adjust the work-group size if needed.
2642        */
2643        {
2644            dpct::has_capability_or_fail(stream->get_device(),
2645                                         {sycl::aspect::fp16});
2646
2647            stream->submit([&](sycl::handler &cgh) {
2648                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2649                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2650                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2651                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2652                    cgh);
2653                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2654                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2655                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2656                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2657                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2658                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2659
2660                cgh.parallel_for(
2661                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2662                    [=](sycl::nd_item<3> item_ct1) {
2663                        mul_mat_q4_K<need_check>(
2664                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2665                            nrows_dst, item_ct1,
2666                            get_pointer(tile_x_ql_q4_K_acc_ct1),
2667                            get_pointer(tile_x_dm_q4_K_acc_ct1),
2668                            get_pointer(tile_x_sc_q4_K_acc_ct1),
2669                            get_pointer(tile_y_qs_acc_ct1),
2670                            get_pointer(tile_y_ds_acc_ct1));
2671                    });
2672            });
2673        }
2674    } else {
2675        const bool need_check = true;
2676        /*
2677        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
2678        the limit. To get the device limit, query
2679        info::device::max_work_group_size. Adjust the work-group size if needed.
2680        */
2681        {
2682            dpct::has_capability_or_fail(stream->get_device(),
2683                                         {sycl::aspect::fp16});
2684
2685            stream->submit([&](sycl::handler &cgh) {
2686                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2687                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2688                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2689                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2690                    cgh);
2691                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2692                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2693                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2694                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2695                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2696                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2697
2698                cgh.parallel_for(
2699                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2700                    [=](sycl::nd_item<3> item_ct1) {
2701                        mul_mat_q4_K<need_check>(
2702                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2703                            nrows_dst, item_ct1,
2704                            get_pointer(tile_x_ql_q4_K_acc_ct1),
2705                            get_pointer(tile_x_dm_q4_K_acc_ct1),
2706                            get_pointer(tile_x_sc_q4_K_acc_ct1),
2707                            get_pointer(tile_y_qs_acc_ct1),
2708                            get_pointer(tile_y_ds_acc_ct1));
2709                    });
2710            });
2711        }
2712    }
2713}
2714catch (sycl::exception const &exc) {
2715  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2716            << ", line:" << __LINE__ << std::endl;
2717  std::exit(1);
2718}
2719
2720static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
2721                                        float *dst, const int ncols_x,
2722                                        const int nrows_x, const int ncols_y,
2723                                        const int nrows_y, const int nrows_dst,
2724                                        dpct::queue_ptr stream) try {
2725
2726    int id;
2727    SYCL_CHECK(
2728        CHECK_TRY_ERROR(id = get_current_device_id()));
2729    const int compute_capability = ggml_sycl_info().devices[id].cc;
2730
2731    int mmq_x, mmq_y, nwarps;
2732    if (compute_capability >= VER_GEN13) {
2733        mmq_x  =  MMQ_X_Q5_K_RDNA2;
2734        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
2735        nwarps = NWARPS_Q5_K_RDNA2;
2736    } else if (compute_capability >= VER_GEN12) {
2737        mmq_x  =  MMQ_X_Q5_K_RDNA1;
2738        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
2739        nwarps = NWARPS_Q5_K_RDNA1;
2740    } else if (compute_capability >= VER_GEN9) {
2741        mmq_x  =  MMQ_X_Q5_K_AMPERE;
2742        mmq_y  =  MMQ_Y_Q5_K_AMPERE;
2743        nwarps = NWARPS_Q5_K_AMPERE;
2744    } else if (compute_capability >= VER_4VEC) {
2745        mmq_x  =  MMQ_X_Q5_K_PASCAL;
2746        mmq_y  =  MMQ_Y_Q5_K_PASCAL;
2747        nwarps = NWARPS_Q5_K_PASCAL;
2748    } else {
2749        GGML_ABORT("fatal error");
2750    }
2751
2752    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2753    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2754    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2755    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2756
2757    if (nrows_x % mmq_y == 0) {
2758        const bool need_check = false;
2759        /*
2760        DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
2761        the limit. To get the device limit, query
2762        info::device::max_work_group_size. Adjust the work-group size if needed.
2763        */
2764        {
2765            dpct::has_capability_or_fail(stream->get_device(),
2766                                         {sycl::aspect::fp16});
2767
2768            stream->submit([&](sycl::handler &cgh) {
2769                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2770                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2771                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2772                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2773                    cgh);
2774                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2775                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2776                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2777                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2778                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2779                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2780
2781                cgh.parallel_for(
2782                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2783                    [=](sycl::nd_item<3> item_ct1) {
2784                        mul_mat_q5_K<need_check>(
2785                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2786                            nrows_dst, item_ct1,
2787                            get_pointer(tile_x_ql_q5_K_acc_ct1),
2788                            get_pointer(tile_x_dm_q5_K_acc_ct1),
2789                            get_pointer(tile_x_sc_q5_K_acc_ct1),
2790                            get_pointer(tile_y_qs_acc_ct1),
2791                            get_pointer(tile_y_ds_acc_ct1));
2792                    });
2793            });
2794        }
2795    } else {
2796        const bool need_check = true;
2797        /*
2798        DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
2799        the limit. To get the device limit, query
2800        info::device::max_work_group_size. Adjust the work-group size if needed.
2801        */
2802        {
2803            dpct::has_capability_or_fail(stream->get_device(),
2804                                         {sycl::aspect::fp16});
2805
2806            stream->submit([&](sycl::handler &cgh) {
2807                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2808                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2809                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2810                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2811                    cgh);
2812                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2813                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2814                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2815                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2816                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2817                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2818
2819                cgh.parallel_for(
2820                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2821                    [=](sycl::nd_item<3> item_ct1) {
2822                        mul_mat_q5_K<need_check>(
2823                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2824                            nrows_dst, item_ct1,
2825                            get_pointer(tile_x_ql_q5_K_acc_ct1),
2826                            get_pointer(tile_x_dm_q5_K_acc_ct1),
2827                            get_pointer(tile_x_sc_q5_K_acc_ct1),
2828                            get_pointer(tile_y_qs_acc_ct1),
2829                            get_pointer(tile_y_ds_acc_ct1));
2830                    });
2831            });
2832        }
2833    }
2834}
2835catch (sycl::exception const &exc) {
2836  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2837            << ", line:" << __LINE__ << std::endl;
2838  std::exit(1);
2839}
2840
2841static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
2842                                        float *dst, const int ncols_x,
2843                                        const int nrows_x, const int ncols_y,
2844                                        const int nrows_y, const int nrows_dst,
2845                                        dpct::queue_ptr stream) try {
2846
2847    int id;
2848    SYCL_CHECK(
2849        CHECK_TRY_ERROR(id = get_current_device_id()));
2850    const int compute_capability = ggml_sycl_info().devices[id].cc;
2851
2852    int mmq_x, mmq_y, nwarps;
2853    if (compute_capability >= VER_GEN13) {
2854        mmq_x  =  MMQ_X_Q6_K_RDNA2;
2855        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
2856        nwarps = NWARPS_Q6_K_RDNA2;
2857    } else if (compute_capability >= VER_GEN12) {
2858        mmq_x  =  MMQ_X_Q6_K_RDNA1;
2859        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
2860        nwarps = NWARPS_Q6_K_RDNA1;
2861    } else if (compute_capability >= VER_GEN9) {
2862        mmq_x  =  MMQ_X_Q6_K_AMPERE;
2863        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
2864        nwarps = NWARPS_Q6_K_AMPERE;
2865    } else if (compute_capability >= VER_4VEC) {
2866        mmq_x  =  MMQ_X_Q6_K_PASCAL;
2867        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
2868        nwarps = NWARPS_Q6_K_PASCAL;
2869    } else {
2870        GGML_ABORT("fatal error");
2871    }
2872
2873    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2874    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2875    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2876    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2877
2878    if (nrows_x % mmq_y == 0) {
2879        const bool need_check = false;
2880        /*
2881        DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
2882        the limit. To get the device limit, query
2883        info::device::max_work_group_size. Adjust the work-group size if needed.
2884        */
2885        {
2886            dpct::has_capability_or_fail(stream->get_device(),
2887                                         {sycl::aspect::fp16});
2888
2889            stream->submit([&](sycl::handler &cgh) {
2890                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2891                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2892                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2893                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2894                    cgh);
2895                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2896                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2897                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2898                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2899                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2900                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2901
2902                cgh.parallel_for(
2903                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2904                    [=](sycl::nd_item<3> item_ct1) {
2905                        mul_mat_q6_K<need_check>(
2906                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2907                            nrows_dst, item_ct1,
2908                            get_pointer(tile_x_ql_acc_ct1),
2909                            get_pointer(tile_x_dm_acc_ct1),
2910                            get_pointer(tile_x_sc_acc_ct1),
2911                            get_pointer(tile_y_qs_acc_ct1),
2912                            get_pointer(tile_y_ds_acc_ct1));
2913                    });
2914            });
2915        }
2916    } else {
2917        const bool need_check = true;
2918        /*
2919        DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
2920        the limit. To get the device limit, query
2921        info::device::max_work_group_size. Adjust the work-group size if needed.
2922        */
2923        {
2924            dpct::has_capability_or_fail(stream->get_device(),
2925                                         {sycl::aspect::fp16});
2926
2927            stream->submit([&](sycl::handler &cgh) {
2928                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2929                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2930                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2931                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2932                    cgh);
2933                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2934                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2935                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2936                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2937                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2938                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2939
2940                cgh.parallel_for(
2941                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
2942                    [=](sycl::nd_item<3> item_ct1) {
2943                        mul_mat_q6_K<need_check>(
2944                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2945                            nrows_dst, item_ct1,
2946                            get_pointer(tile_x_ql_acc_ct1),
2947                            get_pointer(tile_x_dm_acc_ct1),
2948                            get_pointer(tile_x_sc_acc_ct1),
2949                            get_pointer(tile_y_qs_acc_ct1),
2950                            get_pointer(tile_y_ds_acc_ct1));
2951                    });
2952            });
2953        }
2954    }
2955}
2956catch (sycl::exception const &exc) {
2957  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2958            << ", line:" << __LINE__ << std::endl;
2959  std::exit(1);
2960}
2961
2962void ggml_sycl_op_mul_mat_q(
2963    ggml_backend_sycl_context & ctx,
2964    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
2965    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
2966    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
2967    const int64_t src1_ncols, const int64_t src1_padded_row_size,
2968    const dpct::queue_ptr &stream) try {
2969
2970    const int64_t ne00 = src0->ne[0];
2971
2972    const int64_t ne10 = src1->ne[0];
2973    GGML_ASSERT(ne10 % QK8_1 == 0);
2974
2975    const int64_t ne0 = dst->ne[0];
2976
2977    const int64_t row_diff = row_high - row_low;
2978
2979    int device_id;
2980    SYCL_CHECK(
2981        CHECK_TRY_ERROR(device_id = get_current_device_id()));
2982
2983    // the main device has a larger memory buffer to hold the results from all GPUs
2984    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
2985    const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
2986
2987    switch (src0->type) {
2988        case GGML_TYPE_Q4_0:
2989            ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2990            break;
2991        case GGML_TYPE_Q4_1:
2992            ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2993            break;
2994        case GGML_TYPE_Q5_0:
2995            ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2996            break;
2997        case GGML_TYPE_Q5_1:
2998            ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2999            break;
3000        case GGML_TYPE_Q8_0:
3001            ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3002            break;
3003        case GGML_TYPE_Q2_K:
3004            ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3005            break;
3006        case GGML_TYPE_Q3_K:
3007            ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3008            break;
3009        case GGML_TYPE_Q4_K:
3010            ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3011            break;
3012        case GGML_TYPE_Q5_K:
3013            ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3014            break;
3015        case GGML_TYPE_Q6_K:
3016            ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3017            break;
3018        default:
3019            GGML_ABORT("fatal error");
3020    }
3021
3022    GGML_UNUSED(src1);
3023    GGML_UNUSED(dst);
3024    GGML_UNUSED(src1_ddf_i);
3025}
3026catch (sycl::exception const &exc) {
3027  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3028            << ", line:" << __LINE__ << std::endl;
3029  std::exit(1);
3030}