1#include "convert.hpp"
   2#include "dmmv.hpp"
   3#include "dequantize.hpp"
   4#include "presets.hpp"
   5
   6static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
   7    const sycl::half *x = (const sycl::half *)vx;
   8
   9    // automatic half -> float type cast if dfloat == float
  10    v.x() = x[ib + iqs + 0];
  11    v.y() = x[ib + iqs + 1];
  12}
  13
  14static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
  15    const float * x = (const float *) vx;
  16
  17    // automatic half -> float type cast if dfloat == float
  18    v.x() = x[ib + iqs + 0];
  19    v.y() = x[ib + iqs + 1];
  20}
  21
  22template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  23static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
  24                                   const sycl::nd_item<3> &item_ct1) {
  25    // qk = quantized weights per x block
  26    // qr = number of quantized weights per data value in x block
  27    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  28                    item_ct1.get_local_id(1);
  29
  30    if (row >= nrows) {
  31        return;
  32    }
  33
  34    const int tid = item_ct1.get_local_id(2);
  35
  36    const int iter_stride = 2*GGML_SYCL_DMMV_X;
  37    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
  38    const int y_offset = qr == 1 ? 1 : qk/2;
  39
  40// partial sum for each thread
  41#ifdef GGML_SYCL_F16
  42    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
  43#else
  44    float tmp = 0.0f;
  45#endif // GGML_SYCL_F16
  46
  47    for (int i = 0; i < ncols; i += iter_stride) {
  48        const int col = i + vals_per_iter*tid;
  49        const int ib = (row*ncols + col)/qk; // x block index
  50        const int iqs = (col%qk)/qr; // x quant index
  51        const int iybs = col - col%qk; // y block start index
  52
  53// processing >2 values per i iter is faster for fast GPUs
  54#pragma unroll
  55        for (int j = 0; j < vals_per_iter; j += 2) {
  56            // process 2 vals per j iter
  57
  58            // dequantize
  59            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
  60            dfloat2 v;
  61            dequantize_kernel(vx, ib, iqs + j/qr, v);
  62
  63            // matrix multiplication
  64            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
  65#ifdef GGML_SYCL_F16
  66            dfloat2 t1{y[iybs + iqs + j / qr + 0],
  67                        y[iybs + iqs + j / qr + y_offset]};
  68
  69            tmp += v * t1;
  70#else
  71            tmp += v.x() * y[iybs + iqs + j / qr + 0];
  72            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
  73#endif // GGML_SYCL_F16
  74        }
  75    }
  76
  77    // sum up partial sums and write back result
  78    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
  79    for (int mask = mask_start; mask > 0; mask >>= 1) {
  80        tmp +=
  81            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  82    }
  83
  84    if (tid == 0) {
  85#ifdef GGML_SYCL_F16
  86        dst[row] = tmp.x() + tmp.y();
  87#else
  88        dst[row] = tmp;
  89#endif // GGML_SYCL_F16
  90    }
  91}
  92
  93template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
  94static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
  95                                   const sycl::nd_item<3> &item_ct1) {
  96    // qk = quantized weights per x block
  97    // qr = number of quantized weights per data value in x block
  98    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  99                    item_ct1.get_local_id(1);
 100
 101    if (row >= nrows) {
 102        return;
 103    }
 104
 105    const int tid = item_ct1.get_local_id(2);
 106
 107
 108    const int ncols_left = ncols % (QK4_0*WARP_SIZE);
 109    const int ncols_align = ncols - ncols_left;
 110    const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
 111    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
 112    const int y_offset = qr == 1 ? 1 : qk/2;
 113
 114// partial sum for each thread
 115#ifdef GGML_SYCL_F16
 116    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
 117#else
 118    float tmp = 0.0f;
 119#endif // GGML_SYCL_F16
 120    const char *d_ptr = (const char*)vx+ncols*nrows/2;
 121    int i=0;
 122    for (i = 0; i < ncols_align; i += iter_stride) {
 123        const int col = i + vals_per_iter*tid;
 124        const int ib = (row*ncols + col)/qk; // x block index
 125        const int iqs = (col%qk)/qr; // x quant index
 126        const int iybs = col - col%qk; // y block start index
 127
 128// processing >2 values per i iter is faster for fast GPUs
 129#pragma unroll
 130        for (int j = 0; j < vals_per_iter; j += 2) {
 131            // process 2 vals per j iter
 132
 133            // dequantize
 134            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 135            dfloat2 v;
 136            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
 137
 138            // matrix multiplication
 139            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 140#ifdef GGML_SYCL_F16
 141            dfloat2 t1{y[iybs + iqs + j / qr + 0],
 142                        y[iybs + iqs + j / qr + y_offset]};
 143
 144            tmp += v * t1;
 145#else
 146            tmp += v.x() * y[iybs + iqs + j / qr + 0];
 147            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
 148#endif // GGML_SYCL_F16
 149        }
 150    }
 151
 152    for (; i < ncols; i += iter_stride) {
 153        if (tid>=ncols_left/QK4_0) continue;
 154        const int col = i + vals_per_iter*tid;
 155        const int ib = (row*ncols + col)/qk; // x block index
 156        const int iqs = (col%qk)/qr; // x quant index
 157        const int iybs = col - col%qk; // y block start index
 158
 159// processing >2 values per i iter is faster for fast GPUs
 160#pragma unroll
 161        for (int j = 0; j < vals_per_iter; j += 2) {
 162            // process 2 vals per j iter
 163
 164            // dequantize
 165            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 166            dfloat2 v;
 167            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
 168
 169            // matrix multiplication
 170            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 171#ifdef GGML_SYCL_F16
 172            dfloat2 t1{y[iybs + iqs + j / qr + 0],
 173                        y[iybs + iqs + j / qr + y_offset]};
 174
 175            tmp += v * t1;
 176#else
 177            tmp += v.x() * y[iybs + iqs + j / qr + 0];
 178            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
 179#endif // GGML_SYCL_F16
 180        }
 181    }
 182
 183    // sum up partial sums and write back result
 184    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
 185    for (int mask = mask_start; mask > 0; mask >>= 1) {
 186        tmp +=
 187            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
 188    }
 189
 190    if (tid == 0) {
 191#ifdef GGML_SYCL_F16
 192        dst[row] = tmp.x() + tmp.y();
 193#else
 194        dst[row] = tmp;
 195#endif // GGML_SYCL_F16
 196    }
 197}
 198
 199static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
 200                                         float *dst, const int ncols,
 201                                         const int nrows,
 202                                         dpct::queue_ptr stream) {
 203    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 204    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 205    const sycl::range<3> block_nums(1, 1, block_num_y);
 206    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 207    {
 208        dpct::has_capability_or_fail(stream->get_device(),
 209                                     {sycl::aspect::fp16});
 210
 211        stream->parallel_for(
 212            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 213            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 214                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
 215                                                          nrows, item_ct1);
 216            });
 217    }
 218}
 219
 220/*
 221DPCT1110:4: The total declared local variable size in device function
 222dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
 223pressure. Consult with your hardware vendor to find the total register size
 224available and adjust the code, or use smaller sub-group size to avoid high
 225register pressure.
 226*/
 227static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 228                                        const float *__restrict__ yy,
 229                                        float *__restrict__ dst,
 230                                        const int ncols, int nrows,
 231                                        const sycl::nd_item<3> &item_ct1) {
 232
 233    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 234
 235    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
 236                    item_ct1.get_local_id(1);
 237    if (row > nrows) return;
 238
 239    const int num_blocks_per_row = ncols / QK_K;
 240    const int ib0 = row*num_blocks_per_row;
 241
 242    const block_q2_K * x = (const block_q2_K *)vx + ib0;
 243
 244    float tmp = 0; // partial sum for thread in warp
 245
 246#if QK_K == 256
 247    const int tid =
 248        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
 249    const int ix =
 250        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
 251
 252    const int step = 16/K_QUANTS_PER_ITERATION;
 253
 254    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
 255    const int in = tid - step*im;                        // 0...15 or 0...7
 256
 257    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
 258    const int q_offset = 32*im + l0;
 259    const int s_offset = 8*im;
 260    const int y_offset = 128*im + l0;
 261
 262    uint32_t aux[4];
 263    const uint8_t * d = (const uint8_t *)aux;
 264    const uint8_t * m = (const uint8_t *)(aux + 2);
 265
 266    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 267
 268        const float   * y = yy + i * QK_K + y_offset;
 269        const uint8_t * q = x[i].qs + q_offset;
 270
 271        const float dall = x[i].dm[0];
 272        const float dmin = x[i].dm[1];
 273
 274        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
 275        aux[0] = a[0] & 0x0f0f0f0f;
 276        aux[1] = a[1] & 0x0f0f0f0f;
 277        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
 278        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 279
 280        float sum1 = 0, sum2 = 0;
 281        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
 282            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
 283                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
 284                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
 285                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
 286                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
 287                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
 288                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
 289                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
 290            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
 291                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 292
 293        }
 294        tmp += dall * sum1 - dmin * sum2;
 295
 296    }
 297#else
 298    const int tid = item_ct1.get_local_id(2) /
 299                    (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
 300    const int ix = item_ct1.get_local_id(2) %
 301                   (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
 302    const int offset = tid * K_QUANTS_PER_ITERATION;
 303
 304    uint32_t uaux[2];
 305    const uint8_t * d = (const uint8_t *)uaux;
 306
 307
 308    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 309
 310        const float   * y = yy + i * QK_K + offset;
 311        const uint8_t * q = x[i].qs + offset;
 312        const uint32_t * s = (const uint32_t *)x[i].scales;
 313
 314        uaux[0] = s[0] & 0x0f0f0f0f;
 315        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
 316
 317        const sycl::float2 dall =
 318            x[i].dm.convert<float, sycl::rounding_mode::automatic>();
 319
 320        float sum1 = 0, sum2 = 0;
 321        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
 322            const uint8_t ql = q[l];
 323            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
 324                  + y[l+16] * d[1] * ((ql >> 2) & 3)
 325                  + y[l+32] * d[2] * ((ql >> 4) & 3)
 326                  + y[l+48] * d[3] * ((ql >> 6) & 3);
 327            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
 328        }
 329        tmp += dall.x() * sum1 - dall.y() * sum2;
 330    }
 331
 332#endif
 333
 334    // sum up partial sums and write back result
 335#pragma unroll
 336    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
 337        tmp +=
 338            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
 339    }
 340
 341    if (item_ct1.get_local_id(2) == 0) {
 342        dst[row] = tmp;
 343    }
 344}
 345
 346/*
 347DPCT1110:5: The total declared local variable size in device function
 348dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register
 349pressure. Consult with your hardware vendor to find the total register size
 350available and adjust the code, or use smaller sub-group size to avoid high
 351register pressure.
 352*/
 353static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 354                                        const float *__restrict__ yy,
 355                                        float *__restrict__ dst,
 356                                        const int ncols, int nrows,
 357                                        const sycl::nd_item<3> &item_ct1) {
 358
 359    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
 360                    item_ct1.get_local_id(1);
 361    if (row > nrows) return;
 362
 363    const int num_blocks_per_row = ncols / QK_K;
 364    const int ib0 = row*num_blocks_per_row;
 365
 366    const block_q3_K * x = (const block_q3_K *)vx + ib0;
 367
 368    float tmp = 0; // partial sum for thread in warp
 369
 370#if QK_K == 256
 371
 372    const uint16_t kmask1 = 0x0303;
 373    const uint16_t kmask2 = 0x0f0f;
 374
 375    const int tid =
 376        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
 377    const int ix =
 378        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
 379
 380    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
 381    const int step = 16/K_QUANTS_PER_ITERATION;
 382    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
 383    const int in = tid - step*im;                        // 0....15 or 0...7
 384
 385    const uint8_t m = 1 << (4*im);
 386
 387    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
 388    const int q_offset =  32*im + l0;
 389    const int y_offset = 128*im + l0;
 390
 391    uint16_t utmp[4];
 392    const int8_t * s = (const int8_t *)utmp;
 393
 394    const uint16_t s_shift = 4*im;
 395
 396    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 397
 398        const float   * y  = yy + i * QK_K + y_offset;
 399        const uint8_t * q = x[i].qs + q_offset;
 400        const uint8_t * h = x[i].hmask + l0;
 401
 402        const uint16_t * a = (const uint16_t *)x[i].scales;
 403        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
 404        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
 405        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
 406        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
 407
 408        const float d = x[i].d;
 409
 410        float sum = 0;
 411        for (int l = 0; l < n; ++l) {
 412            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
 413                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
 414                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
 415                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
 416            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
 417                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
 418                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
 419                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
 420        }
 421        tmp += d * sum;
 422
 423    }
 424#else
 425
 426    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
 427    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
 428    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
 429    const int in = offset/8;                                 // 0 or 1
 430    const int im = offset%8;                                 // 0...7
 431
 432    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 433
 434        const float   * y = yy + i * QK_K + offset;
 435        const uint8_t * q = x[i].qs + offset;
 436        const uint8_t * s = x[i].scales;
 437
 438        const float dall = (float)x[i].d;
 439
 440        float sum = 0;
 441        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
 442            const uint8_t hl = x[i].hmask[im+l] >> in;
 443            const uint8_t ql = q[l];
 444            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
 445                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
 446                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
 447                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
 448        }
 449        tmp += sum;
 450    }
 451#endif
 452
 453    // sum up partial sums and write back result
 454#pragma unroll
 455    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
 456        tmp +=
 457            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
 458    }
 459
 460    if (item_ct1.get_local_id(2) == 0) {
 461        dst[row] = tmp;
 462    }
 463}
 464
 465/*
 466DPCT1110:6: The total declared local variable size in device function
 467dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
 468pressure. Consult with your hardware vendor to find the total register size
 469available and adjust the code, or use smaller sub-group size to avoid high
 470register pressure.
 471*/
 472static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
 473                                        const float *__restrict__ yy,
 474                                        float *__restrict__ dst,
 475                                        const int ncols, int nrows,
 476                                        const sycl::nd_item<3> &item_ct1) {
 477
 478    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
 479                    item_ct1.get_local_id(1);
 480    if (row > nrows) return;
 481    const int num_blocks_per_row = ncols / QK_K;
 482    const int ib0 = row*num_blocks_per_row;
 483
 484    const block_q4_K * x = (const block_q4_K *)vx + ib0;
 485
 486#if QK_K == 256
 487    const uint16_t kmask1 = 0x3f3f;
 488    const uint16_t kmask2 = 0x0f0f;
 489    const uint16_t kmask3 = 0xc0c0;
 490
 491    const int tid =
 492        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
 493    const int ix =
 494        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
 495
 496    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
 497
 498    const int il  = tid/step;                            // 0...3
 499    const int ir  = tid - step*il;                       // 0...7 or 0...3
 500    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
 501
 502    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
 503    const int in = il%2;
 504
 505    const int l0 = n*(2*ir + in);
 506    const int q_offset = 32*im + l0;
 507    const int y_offset = 64*im + l0;
 508
 509    uint16_t aux[4];
 510    const uint8_t * sc = (const uint8_t *)aux;
 511
 512#if K_QUANTS_PER_ITERATION == 2
 513    uint32_t q32[4];
 514    const uint8_t * q4 = (const uint8_t *)q32;
 515#else
 516    uint16_t q16[4];
 517    const uint8_t * q4 = (const uint8_t *)q16;
 518#endif
 519
 520    float tmp = 0; // partial sum for thread in warp
 521
 522    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 523
 524        const float   * y1 = yy + i*QK_K + y_offset;
 525        const float   * y2 = y1 + 128;
 526
 527        const float dall = x[i].dm[0];
 528        const float dmin = x[i].dm[1];
 529
 530        const uint16_t * a = (const uint16_t *)x[i].scales;
 531        aux[0] = a[im+0] & kmask1;
 532        aux[1] = a[im+2] & kmask1;
 533        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
 534        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 535
 536#if K_QUANTS_PER_ITERATION == 2
 537        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
 538        const uint32_t * q2 = q1 + 16;
 539
 540        q32[0] = q1[0] & 0x0f0f0f0f;
 541        q32[1] = q1[0] & 0xf0f0f0f0;
 542        q32[2] = q2[0] & 0x0f0f0f0f;
 543        q32[3] = q2[0] & 0xf0f0f0f0;
 544
 545        sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
 546        float smin = 0;
 547        for (int l = 0; l < 4; ++l) {
 548            s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
 549            s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
 550            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
 551        }
 552        tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
 553                       s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
 554               dmin * smin;
 555#else
 556        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
 557        const uint16_t * q2 = q1 + 32;
 558
 559        q16[0] = q1[0] & 0x0f0f;
 560        q16[1] = q1[0] & 0xf0f0;
 561        q16[2] = q2[0] & 0x0f0f;
 562        q16[3] = q2[0] & 0xf0f0;
 563
 564        float4 s = {0.f, 0.f, 0.f, 0.f};
 565        float smin = 0;
 566        for (int l = 0; l < 2; ++l) {
 567            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
 568            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
 569            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
 570        }
 571        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
 572#endif
 573
 574    }
 575#else
 576    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
 577    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
 578
 579    const int step = tid * K_QUANTS_PER_ITERATION;
 580
 581    uint16_t aux16[2];
 582    const uint8_t * s = (const uint8_t *)aux16;
 583
 584    float tmp = 0;
 585
 586    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 587        const uint8_t * q = x[i].qs + step;
 588        const float   * y = yy + i*QK_K + step;
 589        const uint16_t * a = (const uint16_t *)x[i].scales;
 590        aux16[0] = a[0] & 0x0f0f;
 591        aux16[1] = (a[0] >> 4) & 0x0f0f;
 592        const float d = (float)x[i].dm[0];
 593        const float m = (float)x[i].dm[1];
 594        float sum = 0.f;
 595        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
 596            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
 597                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
 598                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
 599                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
 600        }
 601        tmp += sum;
 602    }
 603
 604#endif
 605
 606    // sum up partial sums and write back result
 607#pragma unroll
 608    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
 609        tmp +=
 610            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
 611    }
 612
 613    if (tid == 0) {
 614        dst[row] = tmp;
 615    }
 616}
 617
 618/*
 619DPCT1110:7: The total declared local variable size in device function
 620dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
 621pressure. Consult with your hardware vendor to find the total register size
 622available and adjust the code, or use smaller sub-group size to avoid high
 623register pressure.
 624*/
/**
 * Dot product of one Q5_K-quantized row of `vx` with the float vector `yy`:
 * dst[row] = sum_k dequant(x[row, k]) * yy[k].
 *
 * Unlike the q2_k/q3_k/q4_k kernels, the row index is just group(2) and there
 * is no `nrows` bounds check. The launcher presumably starts exactly one group
 * per row (not visible here — confirm). Lanes along local dim 2 split the
 * row's blocks and quants, accumulate a partial sum, and combine it with an
 * XOR butterfly over masks QK_WARP_SIZE/2 .. 1. Lane 0 writes dst[row].
 *
 * @param vx    Q5_K blocks, ncols / QK_K blocks per row
 * @param yy    float input vector of length ncols
 * @param dst   output; one float per row
 * @param ncols row length (integer-divided by QK_K, so presumably a multiple of it)
 */
static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
                                        const float *__restrict__ yy,
                                        float *__restrict__ dst,
                                        const int ncols,
                                        const sycl::nd_item<3> &item_ct1) {

    const int row = item_ct1.get_group(2);
    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q5_K * x = (const block_q5_K *)vx + ib0;

    float tmp = 0; // partial sum for thread in warp

#if QK_K == 256
    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    // Two lanes share each tid; ix selects odd/even blocks, so the stride over blocks is 2.
    const int tid = item_ct1.get_local_id(2) / 2; // 0...15
    const int ix = item_ct1.get_local_id(2) % 2;

    const int il  = tid/4;     // 0...3
    const int ir  = tid - 4*il;// 0...3
    const int n   = 2;

    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int in = il%2;

    const int l0 = n*(2*ir + in);
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    // Masks selecting the 5th (high) bit in qh for the low (y1) and high (y2) halves.
    const uint8_t hm1  = 1 << (2*im);
    const uint8_t hm2  = hm1 << 4;

    // aux holds 8 unpacked 6-bit scale/min bytes, read through sc.
    uint16_t aux[4];
    const uint8_t * sc = (const uint8_t *)aux;

    // q16 holds 16 unpacked low-nibble quants, read through q4.
    uint16_t q16[8];
    const uint8_t * q4 = (const uint8_t *)q16;

    for (int i = ix; i < num_blocks_per_row; i += 2) {

        const uint8_t * ql1 = x[i].qs + q_offset;
        const uint8_t * qh  = x[i].qh + l0;
        const float   * y1  = yy + i*QK_K + y_offset;
        const float   * y2  = y1 + 128;

        const float dall = x[i].dm[0];
        const float dmin = x[i].dm[1];

        const uint16_t * a = (const uint16_t *)x[i].scales;
        aux[0] = a[im+0] & kmask1;
        aux[1] = a[im+2] & kmask1;
        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

        sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        const uint16_t * q1 = (const uint16_t *)ql1;
        const uint16_t * q2 = q1 + 32;
        q16[0] = q1[0] & 0x0f0f;
        q16[1] = q1[8] & 0x0f0f;
        q16[2] = (q1[0] >> 4) & 0x0f0f;
        q16[3] = (q1[8] >> 4) & 0x0f0f;
        q16[4] = q2[0] & 0x0f0f;
        q16[5] = q2[8] & 0x0f0f;
        q16[6] = (q2[0] >> 4) & 0x0f0f;
        q16[7] = (q2[8] >> 4) & 0x0f0f;
        // Each term adds 16 when the corresponding qh bit is set (the 5th quant bit).
        for (int l = 0; l < n; ++l) {
            sum.x() +=
                y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +
                y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));
            sum.y() +=
                y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +
                y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));
            sum.z() +=
                y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) +
                y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));
            sum.w() +=
                y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +
                y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));
            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
        }
        tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +
                       sum.w() * sc[5]) -
               dmin * smin;
    }

#else
    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
    const int step = tid * K_QUANTS_PER_ITERATION;
    const int im = step/8;
    const int in = step%8;

    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
        const uint8_t * q = x[i].qs + step;
        const int8_t  * s = x[i].scales;
        const float   * y = yy + i*QK_K + step;
        const float     d = x[i].d;
        float sum = 0.f;
        // Quants are offset by -16 unless the matching high bit in qh is set.
        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
            const uint8_t h = x[i].qh[in+j] >> im;
            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
        }
        tmp += sum;
    }
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}
 751
 752static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
 753                                        const sycl::nd_item<3> &item_ct1) {
 754
 755    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 756
 757    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
 758                    item_ct1.get_local_id(1);
 759    if (row > nrows) return;
 760
 761    const int num_blocks_per_row = ncols / QK_K;
 762    const int ib0 = row*num_blocks_per_row;
 763
 764    const block_q6_K * x = (const block_q6_K *)vx + ib0;
 765
 766#if QK_K == 256
 767
 768    const int tid =
 769        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
 770    const int ix =
 771        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
 772
 773    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
 774
 775    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
 776    const int in = tid - step*im;                        // 0...15 or 0...7
 777
 778#if K_QUANTS_PER_ITERATION == 1
 779    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
 780    const int is = 0;
 781#else
 782    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
 783    const int is = in / 4;
 784#endif
 785    const int ql_offset = 64*im + l0;
 786    const int qh_offset = 32*im + l0;
 787    const int s_offset  =  8*im + is;
 788    const int y_offset = 128*im + l0;
 789
 790    float tmp = 0; // partial sum for thread in warp
 791
 792    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 793
 794        const float   * y  = yy + i * QK_K + y_offset;
 795        const uint8_t * ql = x[i].ql + ql_offset;
 796        const uint8_t * qh = x[i].qh + qh_offset;
 797        const int8_t  * s  = x[i].scales + s_offset;
 798
 799        const float d = x[i].d;
 800
 801#if K_QUANTS_PER_ITERATION == 1
 802        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
 803                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
 804                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
 805                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
 806                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
 807                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
 808                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
 809                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
 810        tmp += sum;
 811#else
 812        float sum = 0;
 813        for (int l = 0; l < 4; ++l) {
 814            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
 815                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
 816                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
 817                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
 818        }
 819        tmp += sum;
 820#endif
 821
 822    }
 823
 824#else
 825
 826    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...7
 827    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0...3
 828
 829    const int step = tid * K_QUANTS_PER_ITERATION;
 830
 831    float tmp = 0; // partial sum for thread in warp
 832
 833    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 834
 835        const float   * y  = yy + i * QK_K + step;
 836        const uint8_t * ql = x[i].ql + step;
 837        const uint8_t * qh = x[i].qh + step;
 838        const int8_t  * s  = x[i].scales;
 839
 840        const float d = x[i+0].d;
 841
 842        float sum = 0;
 843        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
 844            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
 845                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
 846                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
 847                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
 848        }
 849        tmp += sum;
 850
 851    }
 852
 853#endif
 854
 855    // sum up partial sums and write back result
 856#pragma unroll
 857    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
 858        tmp +=
 859            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
 860    }
 861
 862    if (tid == 0) {
 863        dst[row] = tmp;
 864    }
 865}
 866
 867static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
 868                                             float *dst, const int ncols,
 869                                             const int nrows,
 870                                             dpct::queue_ptr stream) {
 871    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 872    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 873    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
 874    const sycl::range<3> block_nums(1, 1, block_num_y);
 875    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 876    {
 877        dpct::has_capability_or_fail(stream->get_device(),
 878                                     {sycl::aspect::fp16});
 879
 880        stream->parallel_for(
 881            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 882            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 883                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
 884                    vx, y, dst, ncols, nrows, item_ct1);
 885            });
 886    }
 887}
 888
 889
 890static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
 891                                             float *dst, const int ncols,
 892                                             const int nrows,
 893                                             dpct::queue_ptr stream) {
 894    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 895    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 896    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
 897    const sycl::range<3> block_nums(1, 1, block_num_y);
 898    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 899    {
 900        dpct::has_capability_or_fail(stream->get_device(),
 901                                     {sycl::aspect::fp16});
 902
 903        stream->parallel_for(
 904            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 905            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 906                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
 907                    vx, y, dst, ncols, nrows, item_ct1);
 908            });
 909    }
 910}
 911
 912static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
 913                                             float *dst, const int ncols,
 914                                             const int nrows,
 915                                             dpct::queue_ptr stream) {
 916    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 917    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 918    const sycl::range<3> block_nums(1, 1, block_num_y);
 919    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 920    {
 921        dpct::has_capability_or_fail(stream->get_device(),
 922                                     {sycl::aspect::fp16});
 923
 924        stream->parallel_for(
 925            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 926            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 927                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
 928                    vx, y, dst, ncols, nrows, item_ct1);
 929            });
 930    }
 931}
 932
 933static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
 934                                             float *dst, const int ncols,
 935                                             const int nrows,
 936                                             dpct::queue_ptr stream) {
 937    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 938    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 939    const sycl::range<3> block_nums(1, 1, block_num_y);
 940    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 941    {
 942        dpct::has_capability_or_fail(stream->get_device(),
 943                                     {sycl::aspect::fp16});
 944
 945        stream->parallel_for(
 946            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 947            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 948                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
 949                    vx, y, dst, ncols, nrows, item_ct1);
 950            });
 951    }
 952}
 953
 954static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
 955                                             float *dst, const int ncols,
 956                                             const int nrows,
 957                                             dpct::queue_ptr stream) {
 958    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 959    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 960    const sycl::range<3> block_nums(1, 1, block_num_y);
 961    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 962    {
 963        dpct::has_capability_or_fail(stream->get_device(),
 964                                     {sycl::aspect::fp16});
 965
 966        stream->parallel_for(
 967            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 968            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 969                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
 970                    vx, y, dst, ncols, nrows, item_ct1);
 971            });
 972    }
 973}
 974
 975static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
 976                                             float *dst, const int ncols,
 977                                             const int nrows,
 978                                             dpct::queue_ptr stream) {
 979    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
 980    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
 981    const sycl::range<3> block_nums(1, 1, block_num_y);
 982    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
 983    {
 984        dpct::has_capability_or_fail(stream->get_device(),
 985                                     {sycl::aspect::fp16});
 986
 987        stream->parallel_for(
 988            sycl::nd_range<3>(block_nums * block_dims, block_dims),
 989            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
 990                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
 991                    vx, y, dst, ncols, nrows, item_ct1);
 992            });
 993    }
 994}
 995
 996static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
 997                                             float *dst, const int ncols,
 998                                             const int nrows,
 999                                             dpct::queue_ptr stream) {
1000    GGML_ASSERT(ncols % QK_K == 0);
1001    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
1002    const int block_num_y = (nrows + ny - 1) / ny;
1003    const sycl::range<3> block_nums(1, 1, block_num_y);
1004    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1005    stream->parallel_for(
1006        sycl::nd_range<3>(block_nums * block_dims, block_dims),
1007        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1008            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
1009        });
1010}
1011
1012static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
1013                                             float *dst, const int ncols,
1014                                             const int nrows,
1015                                             dpct::queue_ptr stream) {
1016    GGML_ASSERT(ncols % QK_K == 0);
1017    const int ny = 2 / K_QUANTS_PER_ITERATION;
1018    const int block_num_y = (nrows + ny - 1) / ny;
1019    const sycl::range<3> block_nums(1, 1, block_num_y);
1020    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1021    stream->parallel_for(
1022        sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1025        });
1026}
1027
1028static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
1029                                             float *dst, const int ncols,
1030                                             const int nrows,
1031                                             dpct::queue_ptr stream) {
1032    GGML_ASSERT(ncols % QK_K == 0);
1033    const int ny = 2 / K_QUANTS_PER_ITERATION;
1034    const int block_num_y = (nrows + ny - 1) / ny;
1035    const sycl::range<3> block_nums(1, 1, block_num_y);
1036    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1037    stream->parallel_for(
1038        sycl::nd_range<3>(block_nums * block_dims, block_dims),
1039        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1040            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1041        });
1042}
1043
1044static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
1045                                             float *dst, const int ncols,
1046                                             const int nrows,
1047                                             dpct::queue_ptr stream) {
1048    GGML_ASSERT(ncols % QK_K == 0);
1049    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
1050    stream->parallel_for(
1051        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1052        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1053            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1054        });
1055}
1056
1057static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
1058                                             float *dst, const int ncols,
1059                                             const int nrows,
1060                                             dpct::queue_ptr stream) {
1061    GGML_ASSERT(ncols % QK_K == 0);
1062    const int ny = 2 / K_QUANTS_PER_ITERATION;
1063    const int block_num_y = (nrows + ny - 1) / ny;
1064    const sycl::range<3> block_nums(1, 1, block_num_y);
1065    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1066    stream->parallel_for(
1067        sycl::nd_range<3>(block_nums * block_dims, block_dims),
1068        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1069            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1070        });
1071}
1072
1073void ggml_sycl_op_dequantize_mul_mat_vec(
1074    ggml_backend_sycl_context & ctx,
1075    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
1076    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
1077    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
1078    const int64_t src1_ncols, const int64_t src1_padded_row_size,
1079    const dpct::queue_ptr &stream) {
1080
1081    const int64_t ne00 = src0->ne[0];
1082    const int64_t row_diff = row_high - row_low;
1083    GGML_ASSERT(src1->type == GGML_TYPE_F32);
1084    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
1085#ifdef GGML_SYCL_F16
1086    ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());
1087    sycl::half *src1_dfloat = nullptr; // dfloat == half
1088
1089    bool src1_convert_f16 =
1090        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
1091        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
1092        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
1093
1094    if (src1_convert_f16) {
1095        scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
1096                                             " : converting src1 to fp16");
1097        src1_dfloat = src1_dfloat_a.alloc(ne00);
1098        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
1099        GGML_ASSERT(to_fp16_sycl != nullptr);
1100        to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
1101    }
1102#else
1103    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
1104#endif // GGML_SYCL_F16
1105
1106    switch (src0->type) {
1107        case GGML_TYPE_Q4_0:
1108            if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
1109                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
1110                dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1111            } else {
1112                dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1113            }
1114            break;
1115        case GGML_TYPE_Q4_1:
1116            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1117            break;
1118        case GGML_TYPE_Q5_0:
1119            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1120            break;
1121        case GGML_TYPE_Q5_1:
1122            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1123            break;
1124        case GGML_TYPE_Q8_0:
1125            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1126            break;
1127        case GGML_TYPE_Q2_K:
1128            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1129            break;
1130        case GGML_TYPE_Q3_K:
1131            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1132            break;
1133        case GGML_TYPE_Q4_K:
1134            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1135                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1136                // reorder is currently not supported for dmmv
1137                GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
1138            } else {
1139                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1140            }
1141            break;
1142        case GGML_TYPE_Q5_K:
1143            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1144            break;
1145        case GGML_TYPE_Q6_K:
1146            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1147            break;
1148        case GGML_TYPE_F16:
1149            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1150            break;
1151        default:
1152            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
1153            GGML_ABORT("fatal error");
1154    }
1155
1156    GGML_UNUSED(src1);
1157    GGML_UNUSED(dst);
1158    GGML_UNUSED(src1_ddq_i);
1159    GGML_UNUSED(src1_ncols);
1160    GGML_UNUSED(src1_padded_row_size);
1161    GGML_UNUSED(ctx);
1162}