diff options
| field | value | date |
|---|---|---|
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/dmmv.cpp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/dmmv.cpp')

| mode | file | lines |
|---|---|---|
| -rw-r--r-- | llama.cpp/ggml/src/ggml-sycl/dmmv.cpp | 1162 |

1 file changed, 1162 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp b/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp
new file mode 100644
index 0000000..4f27601
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp
@@ -0,0 +1,1162 @@
| 1 | #include "convert.hpp" | ||
| 2 | #include "dmmv.hpp" | ||
| 3 | #include "dequantize.hpp" | ||
| 4 | #include "presets.hpp" | ||
| 5 | |||
| 6 | static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ | ||
| 7 | const sycl::half *x = (const sycl::half *)vx; | ||
| 8 | |||
| 9 | // automatic half -> float type cast if dfloat == float | ||
| 10 | v.x() = x[ib + iqs + 0]; | ||
| 11 | v.y() = x[ib + iqs + 1]; | ||
| 12 | } | ||
| 13 | |||
| 14 | static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ | ||
| 15 | const float * x = (const float *) vx; | ||
| 16 | |||
| 17 | // automatic float -> half type cast if dfloat == half | ||
| 18 | v.x() = x[ib + iqs + 0]; | ||
| 19 | v.y() = x[ib + iqs + 1]; | ||
| 20 | } | ||
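
Both converters write through a `dfloat2`, whose element type flips with the build flag. A minimal sketch of the aliasing these helpers assume (illustrative only; the real definitions live in the backend's preset headers):

```cpp
#include <sycl/sycl.hpp>

// Sketch, assuming the usual ggml-sycl convention: with GGML_SYCL_F16 the
// device float is half, so two partial sums can ride in a single half2;
// otherwise everything stays in float.
#ifdef GGML_SYCL_F16
typedef sycl::half  dfloat;
typedef sycl::half2 dfloat2;
#else
typedef float        dfloat;
typedef sycl::float2 dfloat2;
#endif
```
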
| 21 | |||
| 22 | template <int qk, int qr, dequantize_kernel_t dequantize_kernel> | ||
| 23 | static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, | ||
| 24 | const sycl::nd_item<3> &item_ct1) { | ||
| 25 | // qk = quantized weights per x block | ||
| 26 | // qr = number of quantized weights per data value in x block | ||
| 27 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 28 | item_ct1.get_local_id(1); | ||
| 29 | |||
| 30 | if (row >= nrows) { | ||
| 31 | return; | ||
| 32 | } | ||
| 33 | |||
| 34 | const int tid = item_ct1.get_local_id(2); | ||
| 35 | |||
| 36 | const int iter_stride = 2*GGML_SYCL_DMMV_X; | ||
| 37 | const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread per i iter | ||
| 38 | const int y_offset = qr == 1 ? 1 : qk/2; | ||
| 39 | |||
| 40 | // partial sum for each thread | ||
| 41 | #ifdef GGML_SYCL_F16 | ||
| 42 | sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics | ||
| 43 | #else | ||
| 44 | float tmp = 0.0f; | ||
| 45 | #endif // GGML_SYCL_F16 | ||
| 46 | |||
| 47 | for (int i = 0; i < ncols; i += iter_stride) { | ||
| 48 | const int col = i + vals_per_iter*tid; | ||
| 49 | const int ib = (row*ncols + col)/qk; // x block index | ||
| 50 | const int iqs = (col%qk)/qr; // x quant index | ||
| 51 | const int iybs = col - col%qk; // y block start index | ||
| 52 | |||
| 53 | // processing >2 values per i iter is faster for fast GPUs | ||
| 54 | #pragma unroll | ||
| 55 | for (int j = 0; j < vals_per_iter; j += 2) { | ||
| 56 | // process 2 vals per j iter | ||
| 57 | |||
| 58 | // dequantize | ||
| 59 | // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val | ||
| 60 | dfloat2 v; | ||
| 61 | dequantize_kernel(vx, ib, iqs + j/qr, v); | ||
| 62 | |||
| 63 | // matrix multiplication | ||
| 64 | // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 | ||
| 65 | #ifdef GGML_SYCL_F16 | ||
| 66 | dfloat2 t1{y[iybs + iqs + j / qr + 0], | ||
| 67 | y[iybs + iqs + j / qr + y_offset]}; | ||
| 68 | |||
| 69 | tmp += v * t1; | ||
| 70 | #else | ||
| 71 | tmp += v.x() * y[iybs + iqs + j / qr + 0]; | ||
| 72 | tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; | ||
| 73 | #endif // GGML_SYCL_F16 | ||
| 74 | } | ||
| 75 | } | ||
| 76 | |||
| 77 | // sum up partial sums and write back result | ||
| 78 | const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; | ||
| 79 | for (int mask = mask_start; mask > 0; mask >>= 1) { | ||
| 80 | tmp += | ||
| 81 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 82 | } | ||
| 83 | |||
| 84 | if (tid == 0) { | ||
| 85 | #ifdef GGML_SYCL_F16 | ||
| 86 | dst[row] = tmp.x() + tmp.y(); | ||
| 87 | #else | ||
| 88 | dst[row] = tmp; | ||
| 89 | #endif // GGML_SYCL_F16 | ||
| 90 | } | ||
| 91 | } | ||
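
The `ib`/`iqs`/`iybs` arithmetic above is easy to sanity-check on the host. A standalone sketch that replays it for a q4_0-like layout (qk = 32, qr = 2); the constants `GGML_SYCL_DMMV_X = 32` and `WARP_SIZE = 32` are assumed here purely for illustration:

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical preset values; the real ones come from presets.hpp.
constexpr int GGML_SYCL_DMMV_X = 32;
constexpr int WARP_SIZE        = 32;

int main() {
    const int qk = 32, qr = 2;                         // q4_0: 32 weights, 2 per byte
    const int iter_stride   = 2 * GGML_SYCL_DMMV_X;    // 64 columns per outer iteration
    const int vals_per_iter = iter_stride / WARP_SIZE; // 2 quant values per thread
    const int ncols = 128, row = 3;

    for (int i = 0; i < ncols; i += iter_stride) {
        for (int tid = 0; tid < WARP_SIZE; ++tid) {
            const int col  = i + vals_per_iter * tid;
            const int ib   = (row * ncols + col) / qk; // x block index
            const int iqs  = (col % qk) / qr;          // quant byte inside the block
            const int iybs = col - col % qk;           // start of the matching y block
            assert(iqs < qk / qr);                     // nibble index stays in range
            assert(iybs % qk == 0);                    // y blocks stay qk-aligned
            (void) ib;
        }
    }
    printf("index math consistent for ncols = %d\n", ncols);
    return 0;
}
```
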
| 92 | |||
| 93 | template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder> | ||
| 94 | static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, | ||
| 95 | const sycl::nd_item<3> &item_ct1) { | ||
| 96 | // qk = quantized weights per x block | ||
| 97 | // qr = number of quantized weights per data value in x block | ||
| 98 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 99 | item_ct1.get_local_id(1); | ||
| 100 | |||
| 101 | if (row >= nrows) { | ||
| 102 | return; | ||
| 103 | } | ||
| 104 | |||
| 105 | const int tid = item_ct1.get_local_id(2); | ||
| 106 | |||
| 107 | |||
| 108 | const int ncols_left = ncols % (QK4_0*WARP_SIZE); | ||
| 109 | const int ncols_align = ncols - ncols_left; | ||
| 110 | const int iter_stride = 8*2*GGML_SYCL_DMMV_X; | ||
| 111 | const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread per i iter | ||
| 112 | const int y_offset = qr == 1 ? 1 : qk/2; | ||
| 113 | |||
| 114 | // partial sum for each thread | ||
| 115 | #ifdef GGML_SYCL_F16 | ||
| 116 | sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics | ||
| 117 | #else | ||
| 118 | float tmp = 0.0f; | ||
| 119 | #endif // GGML_SYCL_F16 | ||
| 120 | const char *d_ptr = (const char *) vx + ncols*nrows/2; // per-block d scales follow the packed nibbles in the reordered layout | ||
| 121 | int i = 0; | ||
| 122 | for (i = 0; i < ncols_align; i += iter_stride) { | ||
| 123 | const int col = i + vals_per_iter*tid; | ||
| 124 | const int ib = (row*ncols + col)/qk; // x block index | ||
| 125 | const int iqs = (col%qk)/qr; // x quant index | ||
| 126 | const int iybs = col - col%qk; // y block start index | ||
| 127 | |||
| 128 | // processing >2 values per i iter is faster for fast GPUs | ||
| 129 | #pragma unroll | ||
| 130 | for (int j = 0; j < vals_per_iter; j += 2) { | ||
| 131 | // process 2 vals per j iter | ||
| 132 | |||
| 133 | // dequantize | ||
| 134 | // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val | ||
| 135 | dfloat2 v; | ||
| 136 | dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 + iqs + j/qr, v); | ||
| 137 | |||
| 138 | // matrix multiplication | ||
| 139 | // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 | ||
| 140 | #ifdef GGML_SYCL_F16 | ||
| 141 | dfloat2 t1{y[iybs + iqs + j / qr + 0], | ||
| 142 | y[iybs + iqs + j / qr + y_offset]}; | ||
| 143 | |||
| 144 | tmp += v * t1; | ||
| 145 | #else | ||
| 146 | tmp += v.x() * y[iybs + iqs + j / qr + 0]; | ||
| 147 | tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; | ||
| 148 | #endif // GGML_SYCL_F16 | ||
| 149 | } | ||
| 150 | } | ||
| 151 | |||
| 152 | for (; i < ncols; i += iter_stride) { | ||
| 153 | if (tid >= ncols_left / QK4_0) continue; | ||
| 154 | const int col = i + vals_per_iter*tid; | ||
| 155 | const int ib = (row*ncols + col)/qk; // x block index | ||
| 156 | const int iqs = (col%qk)/qr; // x quant index | ||
| 157 | const int iybs = col - col%qk; // y block start index | ||
| 158 | |||
| 159 | // processing >2 values per i iter is faster for fast GPUs | ||
| 160 | #pragma unroll | ||
| 161 | for (int j = 0; j < vals_per_iter; j += 2) { | ||
| 162 | // process 2 vals per j iter | ||
| 163 | |||
| 164 | // dequantize | ||
| 165 | // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val | ||
| 166 | dfloat2 v; | ||
| 167 | dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 + iqs + j/qr, v); | ||
| 168 | |||
| 169 | // matrix multiplication | ||
| 170 | // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 | ||
| 171 | #ifdef GGML_SYCL_F16 | ||
| 172 | dfloat2 t1{y[iybs + iqs + j / qr + 0], | ||
| 173 | y[iybs + iqs + j / qr + y_offset]}; | ||
| 174 | |||
| 175 | tmp += v * t1; | ||
| 176 | #else | ||
| 177 | tmp += v.x() * y[iybs + iqs + j / qr + 0]; | ||
| 178 | tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; | ||
| 179 | #endif // GGML_SYCL_F16 | ||
| 180 | } | ||
| 181 | } | ||
| 182 | |||
| 183 | // sum up partial sums and write back result | ||
| 184 | const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; | ||
| 185 | for (int mask = mask_start; mask > 0; mask >>= 1) { | ||
| 186 | tmp += | ||
| 187 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 188 | } | ||
| 189 | |||
| 190 | if (tid == 0) { | ||
| 191 | #ifdef GGML_SYCL_F16 | ||
| 192 | dst[row] = tmp.x() + tmp.y(); | ||
| 193 | #else | ||
| 194 | dst[row] = tmp; | ||
| 195 | #endif // GGML_SYCL_F16 | ||
| 196 | } | ||
| 197 | } | ||
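
Reading the kernel, the reordered q4_0 buffer appears to pack all nibble bytes for the tensor first (`ncols*nrows/2` bytes, indexed as `ib * QK4_0/2 + iqs`), with the per-block `d` scales following at `d_ptr`. A host-side sketch of producing that layout from classic `block_q4_0` blocks, under exactly those assumptions (the struct is a simplified stand-in, with the fp16 scale kept opaque):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK4_0 = 32;

// Simplified stand-in for the classic interleaved q4_0 block.
struct block_q4_0 {
    uint16_t d;              // fp16 scale bits, treated as opaque storage
    uint8_t  qs[QK4_0 / 2];  // 16 bytes of packed nibbles
};

// Repack into [all qs bytes][all d scales], matching the kernel's
// d_ptr = (const char *) vx + ncols*nrows/2.
std::vector<uint8_t> reorder_q4_0(const block_q4_0 * src, int64_t nblocks) {
    const size_t qs_bytes = (size_t) nblocks * QK4_0 / 2;
    std::vector<uint8_t> out(qs_bytes + (size_t) nblocks * sizeof(uint16_t));
    uint8_t  * qs_dst = out.data();
    uint16_t * d_dst  = reinterpret_cast<uint16_t *>(out.data() + qs_bytes);
    for (int64_t ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs_dst + ib * QK4_0 / 2, src[ib].qs, QK4_0 / 2);
        d_dst[ib] = src[ib].d;  // scales land contiguously after the nibbles
    }
    return out;
}
```
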
| 198 | |||
| 199 | static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, | ||
| 200 | float *dst, const int ncols, | ||
| 201 | const int nrows, | ||
| 202 | dpct::queue_ptr stream) { | ||
| 203 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 204 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 205 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 206 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 207 | { | ||
| 208 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 209 | {sycl::aspect::fp16}); | ||
| 210 | |||
| 211 | stream->parallel_for( | ||
| 212 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 213 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 214 | dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, | ||
| 215 | nrows, item_ct1); | ||
| 216 | }); | ||
| 217 | } | ||
| 218 | } | ||
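
All of these launchers share one geometry: a `WARP_SIZE`-wide sub-group per row, with `GGML_SYCL_MMV_Y` rows stacked per work-group. A quick sketch of the resulting range math, with `WARP_SIZE = 32` and `GGML_SYCL_MMV_Y = 1` assumed for illustration:

```cpp
#include <cstdio>

int main() {
    // Hypothetical preset values for illustration only.
    const int WARP_SIZE = 32, GGML_SYCL_MMV_Y = 1;
    const int nrows = 4096;

    // Mirrors block_nums = {1, 1, block_num_y}, block_dims = {1, GGML_SYCL_MMV_Y, WARP_SIZE}.
    const int  block_num_y  = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    const long rows_covered = (long) block_num_y * GGML_SYCL_MMV_Y; // >= nrows; extra rows bail out early
    const long work_items   = rows_covered * WARP_SIZE;

    printf("%d work-groups, %ld work-items, one %d-wide sub-group per row\n",
           block_num_y, work_items, WARP_SIZE);
    return 0;
}
```
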
| 219 | |||
| 220 | /* | ||
| 221 | DPCT1110:4: The total declared local variable size in device function | ||
| 222 | dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register | ||
| 223 | pressure. Consult with your hardware vendor to find the total register size | ||
| 224 | available and adjust the code, or use smaller sub-group size to avoid high | ||
| 225 | register pressure. | ||
| 226 | */ | ||
| 227 | static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, | ||
| 228 | const float *__restrict__ yy, | ||
| 229 | float *__restrict__ dst, | ||
| 230 | const int ncols, int nrows, | ||
| 231 | const sycl::nd_item<3> &item_ct1) { | ||
| 232 | |||
| 233 | static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); | ||
| 234 | |||
| 235 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 236 | item_ct1.get_local_id(1); | ||
| 237 | if (row >= nrows) return; | ||
| 238 | |||
| 239 | const int num_blocks_per_row = ncols / QK_K; | ||
| 240 | const int ib0 = row*num_blocks_per_row; | ||
| 241 | |||
| 242 | const block_q2_K * x = (const block_q2_K *)vx + ib0; | ||
| 243 | |||
| 244 | float tmp = 0; // partial sum for thread in warp | ||
| 245 | |||
| 246 | #if QK_K == 256 | ||
| 247 | const int tid = | ||
| 248 | item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 | ||
| 249 | const int ix = | ||
| 250 | item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 | ||
| 251 | |||
| 252 | const int step = 16/K_QUANTS_PER_ITERATION; | ||
| 253 | |||
| 254 | const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... | ||
| 255 | const int in = tid - step*im; // 0...15 or 0...7 | ||
| 256 | |||
| 257 | const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 | ||
| 258 | const int q_offset = 32*im + l0; | ||
| 259 | const int s_offset = 8*im; | ||
| 260 | const int y_offset = 128*im + l0; | ||
| 261 | |||
| 262 | uint32_t aux[4]; | ||
| 263 | const uint8_t * d = (const uint8_t *)aux; | ||
| 264 | const uint8_t * m = (const uint8_t *)(aux + 2); | ||
| 265 | |||
| 266 | for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { | ||
| 267 | |||
| 268 | const float * y = yy + i * QK_K + y_offset; | ||
| 269 | const uint8_t * q = x[i].qs + q_offset; | ||
| 270 | |||
| 271 | const float dall = x[i].dm[0]; | ||
| 272 | const float dmin = x[i].dm[1]; | ||
| 273 | |||
| 274 | const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); | ||
| 275 | aux[0] = a[0] & 0x0f0f0f0f; | ||
| 276 | aux[1] = a[1] & 0x0f0f0f0f; | ||
| 277 | aux[2] = (a[0] >> 4) & 0x0f0f0f0f; | ||
| 278 | aux[3] = (a[1] >> 4) & 0x0f0f0f0f; | ||
| 279 | |||
| 280 | float sum1 = 0, sum2 = 0; | ||
| 281 | for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { | ||
| 282 | sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) | ||
| 283 | + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) | ||
| 284 | + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) | ||
| 285 | + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) | ||
| 286 | + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) | ||
| 287 | + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) | ||
| 288 | + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) | ||
| 289 | +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); | ||
| 290 | sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[l+96] * m[6] | ||
| 291 | + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; | ||
| 292 | |||
| 293 | } | ||
| 294 | tmp += dall * sum1 - dmin * sum2; | ||
| 295 | |||
| 296 | } | ||
| 297 | #else | ||
| 298 | const int tid = item_ct1.get_local_id(2) / | ||
| 299 | (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7 | ||
| 300 | const int ix = item_ct1.get_local_id(2) % | ||
| 301 | (2 * K_QUANTS_PER_ITERATION); // 0...1 or 0...3 | ||
| 302 | const int offset = tid * K_QUANTS_PER_ITERATION; | ||
| 303 | |||
| 304 | uint32_t uaux[2]; | ||
| 305 | const uint8_t * d = (const uint8_t *)uaux; | ||
| 306 | |||
| 307 | |||
| 308 | for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { | ||
| 309 | |||
| 310 | const float * y = yy + i * QK_K + offset; | ||
| 311 | const uint8_t * q = x[i].qs + offset; | ||
| 312 | const uint32_t * s = (const uint32_t *)x[i].scales; | ||
| 313 | |||
| 314 | uaux[0] = s[0] & 0x0f0f0f0f; | ||
| 315 | uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; | ||
| 316 | |||
| 317 | const sycl::float2 dall = | ||
| 318 | x[i].dm.convert<float, sycl::rounding_mode::automatic>(); | ||
| 319 | |||
| 320 | float sum1 = 0, sum2 = 0; | ||
| 321 | for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { | ||
| 322 | const uint8_t ql = q[l]; | ||
| 323 | sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) | ||
| 324 | + y[l+16] * d[1] * ((ql >> 2) & 3) | ||
| 325 | + y[l+32] * d[2] * ((ql >> 4) & 3) | ||
| 326 | + y[l+48] * d[3] * ((ql >> 6) & 3); | ||
| 327 | sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; | ||
| 328 | } | ||
| 329 | tmp += dall.x() * sum1 - dall.y() * sum2; | ||
| 330 | } | ||
| 331 | |||
| 332 | #endif | ||
| 333 | |||
| 334 | // sum up partial sums and write back result | ||
| 335 | #pragma unroll | ||
| 336 | for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { | ||
| 337 | tmp += | ||
| 338 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 339 | } | ||
| 340 | |||
| 341 | if (item_ct1.get_local_id(2) == 0) { | ||
| 342 | dst[row] = tmp; | ||
| 343 | } | ||
| 344 | } | ||
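
For reference, the same q2_K math in scalar form: each scale byte packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble, and the kernel's `dall * sum1 - dmin * sum2` is a dot product against weights of the form `d*sc*q - dmin*m`. A host sketch dequantizing one 256-value block, with an illustrative stand-in for `block_q2_K` (the real struct stores the two super-block scales as fp16 in `dm`):

```cpp
#include <cstdint>

constexpr int QK_K = 256;

struct block_q2_K {              // illustrative stand-in
    uint8_t scales[QK_K / 16];   // low nibble: scale, high nibble: min
    uint8_t qs[QK_K / 4];        // 2-bit quants, four per byte
    float   d, dmin;             // super-block scale and min
};

void dequantize_block_q2_K(const block_q2_K & x, float * y) {
    const uint8_t * q = x.qs;
    int is = 0;                                       // scale byte cursor
    for (int n = 0; n < QK_K; n += 128) {             // two 128-value halves (the kernel's im)
        for (int shift = 0; shift < 8; shift += 2) {  // 2-bit lanes, like the kernel's >> 0/2/4/6
            for (int part = 0; part < 2; ++part) {    // 16-weight sub-blocks
                const uint8_t sc = x.scales[is++];
                const float dl = x.d    * (sc & 0xF);
                const float ml = x.dmin * (sc >>  4);
                for (int l = 0; l < 16; ++l) {
                    *y++ = dl * ((q[16 * part + l] >> shift) & 3) - ml;
                }
            }
        }
        q += 32;
    }
}
```
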
| 345 | |||
| 346 | /* | ||
| 347 | DPCT1110:5: The total declared local variable size in device function | ||
| 348 | dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register | ||
| 349 | pressure. Consult with your hardware vendor to find the total register size | ||
| 350 | available and adjust the code, or use smaller sub-group size to avoid high | ||
| 351 | register pressure. | ||
| 352 | */ | ||
| 353 | static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, | ||
| 354 | const float *__restrict__ yy, | ||
| 355 | float *__restrict__ dst, | ||
| 356 | const int ncols, int nrows, | ||
| 357 | const sycl::nd_item<3> &item_ct1) { | ||
| 358 | |||
| 359 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 360 | item_ct1.get_local_id(1); | ||
| 361 | if (row >= nrows) return; | ||
| 362 | |||
| 363 | const int num_blocks_per_row = ncols / QK_K; | ||
| 364 | const int ib0 = row*num_blocks_per_row; | ||
| 365 | |||
| 366 | const block_q3_K * x = (const block_q3_K *)vx + ib0; | ||
| 367 | |||
| 368 | float tmp = 0; // partial sum for thread in warp | ||
| 369 | |||
| 370 | #if QK_K == 256 | ||
| 371 | |||
| 372 | const uint16_t kmask1 = 0x0303; | ||
| 373 | const uint16_t kmask2 = 0x0f0f; | ||
| 374 | |||
| 375 | const int tid = | ||
| 376 | item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 | ||
| 377 | const int ix = | ||
| 378 | item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 | ||
| 379 | |||
| 380 | const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop | ||
| 381 | const int step = 16/K_QUANTS_PER_ITERATION; | ||
| 382 | const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... | ||
| 383 | const int in = tid - step*im; // 0...15 or 0...7 | ||
| 384 | |||
| 385 | const uint8_t m = 1 << (4*im); | ||
| 386 | |||
| 387 | const int l0 = n*in; // 0...15 or 0...14 in steps of 2 | ||
| 388 | const int q_offset = 32*im + l0; | ||
| 389 | const int y_offset = 128*im + l0; | ||
| 390 | |||
| 391 | uint16_t utmp[4]; | ||
| 392 | const int8_t * s = (const int8_t *)utmp; | ||
| 393 | |||
| 394 | const uint16_t s_shift = 4*im; | ||
| 395 | |||
| 396 | for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { | ||
| 397 | |||
| 398 | const float * y = yy + i * QK_K + y_offset; | ||
| 399 | const uint8_t * q = x[i].qs + q_offset; | ||
| 400 | const uint8_t * h = x[i].hmask + l0; | ||
| 401 | |||
| 402 | const uint16_t * a = (const uint16_t *)x[i].scales; | ||
| 403 | utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); | ||
| 404 | utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); | ||
| 405 | utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); | ||
| 406 | utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); | ||
| 407 | |||
| 408 | const float d = x[i].d; | ||
| 409 | |||
| 410 | float sum = 0; | ||
| 411 | for (int l = 0; l < n; ++l) { | ||
| 412 | sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) | ||
| 413 | + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) | ||
| 414 | + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) | ||
| 415 | + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); | ||
| 416 | sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) | ||
| 417 | + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) | ||
| 418 | + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) | ||
| 419 | + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); | ||
| 420 | } | ||
| 421 | tmp += d * sum; | ||
| 422 | |||
| 423 | } | ||
| 424 | #else | ||
| 425 | |||
| 426 | const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 | ||
| 427 | const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...1 or 0...3 | ||
| 428 | const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 | ||
| 429 | const int in = offset/8; // 0 or 1 | ||
| 430 | const int im = offset%8; // 0...7 | ||
| 431 | |||
| 432 | for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { | ||
| 433 | |||
| 434 | const float * y = yy + i * QK_K + offset; | ||
| 435 | const uint8_t * q = x[i].qs + offset; | ||
| 436 | const uint8_t * s = x[i].scales; | ||
| 437 | |||
| 438 | const float dall = (float)x[i].d; | ||
| 439 | |||
| 440 | float sum = 0; | ||
| 441 | for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { | ||
| 442 | const uint8_t hl = x[i].hmask[im+l] >> in; | ||
| 443 | const uint8_t ql = q[l]; | ||
| 444 | sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) | ||
| 445 | + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) | ||
| 446 | + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) | ||
| 447 | + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); | ||
| 448 | } | ||
| 449 | tmp += sum; | ||
| 450 | } | ||
| 451 | #endif | ||
| 452 | |||
| 453 | // sum up partial sums and write back result | ||
| 454 | #pragma unroll | ||
| 455 | for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { | ||
| 456 | tmp += | ||
| 457 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 458 | } | ||
| 459 | |||
| 460 | if (item_ct1.get_local_id(2) == 0) { | ||
| 461 | dst[row] = tmp; | ||
| 462 | } | ||
| 463 | } | ||
| 464 | |||
| 465 | /* | ||
| 466 | DPCT1110:6: The total declared local variable size in device function | ||
| 467 | dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register | ||
| 468 | pressure. Consult with your hardware vendor to find the total register size | ||
| 469 | available and adjust the code, or use smaller sub-group size to avoid high | ||
| 470 | register pressure. | ||
| 471 | */ | ||
| 472 | static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, | ||
| 473 | const float *__restrict__ yy, | ||
| 474 | float *__restrict__ dst, | ||
| 475 | const int ncols, int nrows, | ||
| 476 | const sycl::nd_item<3> &item_ct1) { | ||
| 477 | |||
| 478 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 479 | item_ct1.get_local_id(1); | ||
| 480 | if (row >= nrows) return; | ||
| 481 | const int num_blocks_per_row = ncols / QK_K; | ||
| 482 | const int ib0 = row*num_blocks_per_row; | ||
| 483 | |||
| 484 | const block_q4_K * x = (const block_q4_K *)vx + ib0; | ||
| 485 | |||
| 486 | #if QK_K == 256 | ||
| 487 | const uint16_t kmask1 = 0x3f3f; | ||
| 488 | const uint16_t kmask2 = 0x0f0f; | ||
| 489 | const uint16_t kmask3 = 0xc0c0; | ||
| 490 | |||
| 491 | const int tid = | ||
| 492 | item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 | ||
| 493 | const int ix = | ||
| 494 | item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 | ||
| 495 | |||
| 496 | const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 | ||
| 497 | |||
| 498 | const int il = tid/step; // 0...3 | ||
| 499 | const int ir = tid - step*il; // 0...7 or 0...3 | ||
| 500 | const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 | ||
| 501 | |||
| 502 | const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 | ||
| 503 | const int in = il%2; | ||
| 504 | |||
| 505 | const int l0 = n*(2*ir + in); | ||
| 506 | const int q_offset = 32*im + l0; | ||
| 507 | const int y_offset = 64*im + l0; | ||
| 508 | |||
| 509 | uint16_t aux[4]; | ||
| 510 | const uint8_t * sc = (const uint8_t *)aux; | ||
| 511 | |||
| 512 | #if K_QUANTS_PER_ITERATION == 2 | ||
| 513 | uint32_t q32[4]; | ||
| 514 | const uint8_t * q4 = (const uint8_t *)q32; | ||
| 515 | #else | ||
| 516 | uint16_t q16[4]; | ||
| 517 | const uint8_t * q4 = (const uint8_t *)q16; | ||
| 518 | #endif | ||
| 519 | |||
| 520 | float tmp = 0; // partial sum for thread in warp | ||
| 521 | |||
| 522 | for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { | ||
| 523 | |||
| 524 | const float * y1 = yy + i*QK_K + y_offset; | ||
| 525 | const float * y2 = y1 + 128; | ||
| 526 | |||
| 527 | const float dall = x[i].dm[0]; | ||
| 528 | const float dmin = x[i].dm[1]; | ||
| 529 | |||
| 530 | const uint16_t * a = (const uint16_t *)x[i].scales; | ||
| 531 | aux[0] = a[im+0] & kmask1; | ||
| 532 | aux[1] = a[im+2] & kmask1; | ||
| 533 | aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); | ||
| 534 | aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); | ||
| 535 | |||
| 536 | #if K_QUANTS_PER_ITERATION == 2 | ||
| 537 | const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); | ||
| 538 | const uint32_t * q2 = q1 + 16; | ||
| 539 | |||
| 540 | q32[0] = q1[0] & 0x0f0f0f0f; | ||
| 541 | q32[1] = q1[0] & 0xf0f0f0f0; | ||
| 542 | q32[2] = q2[0] & 0x0f0f0f0f; | ||
| 543 | q32[3] = q2[0] & 0xf0f0f0f0; | ||
| 544 | |||
| 545 | sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; | ||
| 546 | float smin = 0; | ||
| 547 | for (int l = 0; l < 4; ++l) { | ||
| 548 | s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; | ||
| 549 | s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; | ||
| 550 | smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; | ||
| 551 | } | ||
| 552 | tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + | ||
| 553 | s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - | ||
| 554 | dmin * smin; | ||
| 555 | #else | ||
| 556 | const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); | ||
| 557 | const uint16_t * q2 = q1 + 32; | ||
| 558 | |||
| 559 | q16[0] = q1[0] & 0x0f0f; | ||
| 560 | q16[1] = q1[0] & 0xf0f0; | ||
| 561 | q16[2] = q2[0] & 0x0f0f; | ||
| 562 | q16[3] = q2[0] & 0xf0f0; | ||
| 563 | |||
| 564 | sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; | ||
| 565 | float smin = 0; | ||
| 566 | for (int l = 0; l < 2; ++l) { | ||
| 567 | s.x() += y1[l] * q4[l+0]; s.y() += y1[l+32] * q4[l+2]; | ||
| 568 | s.z() += y2[l] * q4[l+4]; s.w() += y2[l+32] * q4[l+6]; | ||
| 569 | smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; | ||
| 570 | } | ||
| 571 | tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f/16.f + s.z() * sc[4] + s.w() * sc[5] * 1.f/16.f) - dmin * smin; | ||
| 572 | #endif | ||
| 573 | |||
| 574 | } | ||
| 575 | #else | ||
| 576 | const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 | ||
| 577 | const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); | ||
| 578 | |||
| 579 | const int step = tid * K_QUANTS_PER_ITERATION; | ||
| 580 | |||
| 581 | uint16_t aux16[2]; | ||
| 582 | const uint8_t * s = (const uint8_t *)aux16; | ||
| 583 | |||
| 584 | float tmp = 0; | ||
| 585 | |||
| 586 | for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { | ||
| 587 | const uint8_t * q = x[i].qs + step; | ||
| 588 | const float * y = yy + i*QK_K + step; | ||
| 589 | const uint16_t * a = (const uint16_t *)x[i].scales; | ||
| 590 | aux16[0] = a[0] & 0x0f0f; | ||
| 591 | aux16[1] = (a[0] >> 4) & 0x0f0f; | ||
| 592 | const float d = (float)x[i].dm[0]; | ||
| 593 | const float m = (float)x[i].dm[1]; | ||
| 594 | float sum = 0.f; | ||
| 595 | for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { | ||
| 596 | sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) | ||
| 597 | + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) | ||
| 598 | + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) | ||
| 599 | + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); | ||
| 600 | } | ||
| 601 | tmp += sum; | ||
| 602 | } | ||
| 603 | |||
| 604 | #endif | ||
| 605 | |||
| 606 | // sum up partial sums and write back result | ||
| 607 | #pragma unroll | ||
| 608 | for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { | ||
| 609 | tmp += | ||
| 610 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 611 | } | ||
| 612 | |||
| 613 | if (tid == 0) { | ||
| 614 | dst[row] = tmp; | ||
| 615 | } | ||
| 616 | } | ||
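
The `kmask1/kmask2/kmask3` manipulation above unpacks q4_K's 6-bit scales and mins from their 12-byte packed form. Written out per 32-weight sub-block it is the familiar scale/min extraction (an illustrative restatement of the helper ggml's CPU path uses for this format):

```cpp
#include <cstdint>

// Unpack the j-th 6-bit (scale, min) pair, j in [0, 8), from q4_K's
// 12-byte packed scales array.
static void get_scale_min_q4_K(int j, const uint8_t * scales,
                               uint8_t & sc, uint8_t & m) {
    if (j < 4) {
        // First four pairs: plain 6-bit fields.
        sc = scales[j] & 63;
        m  = scales[j + 4] & 63;
    } else {
        // Last four: low 4 bits here, top 2 bits spilled into earlier bytes.
        sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4);
        m  = (scales[j + 4] >>  4) | ((scales[j    ] >> 6) << 4);
    }
}
```

Each sub-block then contributes `d * sc * q - dmin * m`, which is exactly the kernel's `dall * (...) - dmin * smin` accumulation.
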
| 617 | |||
| 618 | /* | ||
| 619 | DPCT1110:7: The total declared local variable size in device function | ||
| 620 | dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register | ||
| 621 | pressure. Consult with your hardware vendor to find the total register size | ||
| 622 | available and adjust the code, or use smaller sub-group size to avoid high | ||
| 623 | register pressure. | ||
| 624 | */ | ||
| 625 | static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, | ||
| 626 | const float *__restrict__ yy, | ||
| 627 | float *__restrict__ dst, | ||
| 628 | const int ncols, | ||
| 629 | const sycl::nd_item<3> &item_ct1) { | ||
| 630 | |||
| 631 | const int row = item_ct1.get_group(2); | ||
| 632 | const int num_blocks_per_row = ncols / QK_K; | ||
| 633 | const int ib0 = row*num_blocks_per_row; | ||
| 634 | |||
| 635 | const block_q5_K * x = (const block_q5_K *)vx + ib0; | ||
| 636 | |||
| 637 | float tmp = 0; // partial sum for thread in warp | ||
| 638 | |||
| 639 | #if QK_K == 256 | ||
| 640 | const uint16_t kmask1 = 0x3f3f; | ||
| 641 | const uint16_t kmask2 = 0x0f0f; | ||
| 642 | const uint16_t kmask3 = 0xc0c0; | ||
| 643 | |||
| 644 | const int tid = item_ct1.get_local_id(2) / 2; // 0...15 | ||
| 645 | const int ix = item_ct1.get_local_id(2) % 2; | ||
| 646 | |||
| 647 | const int il = tid/4; // 0...3 | ||
| 648 | const int ir = tid - 4*il;// 0...3 | ||
| 649 | const int n = 2; | ||
| 650 | |||
| 651 | const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 | ||
| 652 | const int in = il%2; | ||
| 653 | |||
| 654 | const int l0 = n*(2*ir + in); | ||
| 655 | const int q_offset = 32*im + l0; | ||
| 656 | const int y_offset = 64*im + l0; | ||
| 657 | |||
| 658 | const uint8_t hm1 = 1 << (2*im); | ||
| 659 | const uint8_t hm2 = hm1 << 4; | ||
| 660 | |||
| 661 | uint16_t aux[4]; | ||
| 662 | const uint8_t * sc = (const uint8_t *)aux; | ||
| 663 | |||
| 664 | uint16_t q16[8]; | ||
| 665 | const uint8_t * q4 = (const uint8_t *)q16; | ||
| 666 | |||
| 667 | for (int i = ix; i < num_blocks_per_row; i += 2) { | ||
| 668 | |||
| 669 | const uint8_t * ql1 = x[i].qs + q_offset; | ||
| 670 | const uint8_t * qh = x[i].qh + l0; | ||
| 671 | const float * y1 = yy + i*QK_K + y_offset; | ||
| 672 | const float * y2 = y1 + 128; | ||
| 673 | |||
| 674 | const float dall = x[i].dm[0]; | ||
| 675 | const float dmin = x[i].dm[1]; | ||
| 676 | |||
| 677 | const uint16_t * a = (const uint16_t *)x[i].scales; | ||
| 678 | aux[0] = a[im+0] & kmask1; | ||
| 679 | aux[1] = a[im+2] & kmask1; | ||
| 680 | aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); | ||
| 681 | aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); | ||
| 682 | |||
| 683 | sycl::float4 sum = {0.f, 0.f, 0.f, 0.f}; | ||
| 684 | float smin = 0; | ||
| 685 | const uint16_t * q1 = (const uint16_t *)ql1; | ||
| 686 | const uint16_t * q2 = q1 + 32; | ||
| 687 | q16[0] = q1[0] & 0x0f0f; | ||
| 688 | q16[1] = q1[8] & 0x0f0f; | ||
| 689 | q16[2] = (q1[0] >> 4) & 0x0f0f; | ||
| 690 | q16[3] = (q1[8] >> 4) & 0x0f0f; | ||
| 691 | q16[4] = q2[0] & 0x0f0f; | ||
| 692 | q16[5] = q2[8] & 0x0f0f; | ||
| 693 | q16[6] = (q2[0] >> 4) & 0x0f0f; | ||
| 694 | q16[7] = (q2[8] >> 4) & 0x0f0f; | ||
| 695 | for (int l = 0; l < n; ++l) { | ||
| 696 | sum.x() += | ||
| 697 | y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) + | ||
| 698 | y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0)); | ||
| 699 | sum.y() += | ||
| 700 | y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) + | ||
| 701 | y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0)); | ||
| 702 | sum.z() += | ||
| 703 | y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) + | ||
| 704 | y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0)); | ||
| 705 | sum.w() += | ||
| 706 | y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) + | ||
| 707 | y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0)); | ||
| 708 | smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] | ||
| 709 | + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; | ||
| 710 | } | ||
| 711 | tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] + | ||
| 712 | sum.w() * sc[5]) - | ||
| 713 | dmin * smin; | ||
| 714 | } | ||
| 715 | |||
| 716 | #else | ||
| 717 | const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 | ||
| 718 | const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); | ||
| 719 | const int step = tid * K_QUANTS_PER_ITERATION; | ||
| 720 | const int im = step/8; | ||
| 721 | const int in = step%8; | ||
| 722 | |||
| 723 | for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { | ||
| 724 | const uint8_t * q = x[i].qs + step; | ||
| 725 | const int8_t * s = x[i].scales; | ||
| 726 | const float * y = yy + i*QK_K + step; | ||
| 727 | const float d = x[i].d; | ||
| 728 | float sum = 0.f; | ||
| 729 | for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { | ||
| 730 | const uint8_t h = x[i].qh[in+j] >> im; | ||
| 731 | sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) | ||
| 732 | + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) | ||
| 733 | + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) | ||
| 734 | + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); | ||
| 735 | } | ||
| 736 | tmp += sum; | ||
| 737 | } | ||
| 738 | #endif | ||
| 739 | |||
| 740 | // sum up partial sums and write back result | ||
| 741 | #pragma unroll | ||
| 742 | for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { | ||
| 743 | tmp += | ||
| 744 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 745 | } | ||
| 746 | |||
| 747 | if (item_ct1.get_local_id(2) == 0) { | ||
| 748 | dst[row] = tmp; | ||
| 749 | } | ||
| 750 | } | ||
| 751 | |||
| 752 | static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, | ||
| 753 | const sycl::nd_item<3> &item_ct1) { | ||
| 754 | |||
| 755 | static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); | ||
| 756 | |||
| 757 | const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + | ||
| 758 | item_ct1.get_local_id(1); | ||
| 759 | if (row >= nrows) return; | ||
| 760 | |||
| 761 | const int num_blocks_per_row = ncols / QK_K; | ||
| 762 | const int ib0 = row*num_blocks_per_row; | ||
| 763 | |||
| 764 | const block_q6_K * x = (const block_q6_K *)vx + ib0; | ||
| 765 | |||
| 766 | #if QK_K == 256 | ||
| 767 | |||
| 768 | const int tid = | ||
| 769 | item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 | ||
| 770 | const int ix = | ||
| 771 | item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 | ||
| 772 | |||
| 773 | const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 | ||
| 774 | |||
| 775 | const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... | ||
| 776 | const int in = tid - step*im; // 0...15 or 0...7 | ||
| 777 | |||
| 778 | #if K_QUANTS_PER_ITERATION == 1 | ||
| 779 | const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 | ||
| 780 | const int is = 0; | ||
| 781 | #else | ||
| 782 | const int l0 = 4 * in; // 0, 4, 8, ..., 28 | ||
| 783 | const int is = in / 4; | ||
| 784 | #endif | ||
| 785 | const int ql_offset = 64*im + l0; | ||
| 786 | const int qh_offset = 32*im + l0; | ||
| 787 | const int s_offset = 8*im + is; | ||
| 788 | const int y_offset = 128*im + l0; | ||
| 789 | |||
| 790 | float tmp = 0; // partial sum for thread in warp | ||
| 791 | |||
| 792 | for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { | ||
| 793 | |||
| 794 | const float * y = yy + i * QK_K + y_offset; | ||
| 795 | const uint8_t * ql = x[i].ql + ql_offset; | ||
| 796 | const uint8_t * qh = x[i].qh + qh_offset; | ||
| 797 | const int8_t * s = x[i].scales + s_offset; | ||
| 798 | |||
| 799 | const float d = x[i].d; | ||
| 800 | |||
| 801 | #if K_QUANTS_PER_ITERATION == 1 | ||
| 802 | float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) | ||
| 803 | + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) | ||
| 804 | + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) | ||
| 805 | + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) | ||
| 806 | + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) | ||
| 807 | + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) | ||
| 808 | + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) | ||
| 809 | +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); | ||
| 810 | tmp += sum; | ||
| 811 | #else | ||
| 812 | float sum = 0; | ||
| 813 | for (int l = 0; l < 4; ++l) { | ||
| 814 | sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) | ||
| 815 | + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) | ||
| 816 | + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) | ||
| 817 | + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); | ||
| 818 | } | ||
| 819 | tmp += sum; | ||
| 820 | #endif | ||
| 821 | |||
| 822 | } | ||
| 823 | |||
| 824 | #else | ||
| 825 | |||
| 826 | const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 | ||
| 827 | const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 | ||
| 828 | |||
| 829 | const int step = tid * K_QUANTS_PER_ITERATION; | ||
| 830 | |||
| 831 | float tmp = 0; // partial sum for thread in warp | ||
| 832 | |||
| 833 | for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { | ||
| 834 | |||
| 835 | const float * y = yy + i * QK_K + step; | ||
| 836 | const uint8_t * ql = x[i].ql + step; | ||
| 837 | const uint8_t * qh = x[i].qh + step; | ||
| 838 | const int8_t * s = x[i].scales; | ||
| 839 | |||
| 840 | const float d = x[i].d; | ||
| 841 | |||
| 842 | float sum = 0; | ||
| 843 | for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { | ||
| 844 | sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) | ||
| 845 | + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) | ||
| 846 | + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) | ||
| 847 | + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); | ||
| 848 | } | ||
| 849 | tmp += sum; | ||
| 850 | |||
| 851 | } | ||
| 852 | |||
| 853 | #endif | ||
| 854 | |||
| 855 | // sum up partial sums and write back result | ||
| 856 | #pragma unroll | ||
| 857 | for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { | ||
| 858 | tmp += | ||
| 859 | dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); | ||
| 860 | } | ||
| 861 | |||
| 862 | if (tid == 0) { | ||
| 863 | dst[row] = tmp; | ||
| 864 | } | ||
| 865 | } | ||
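
q6_K weights are reassembled from a low nibble in `ql` plus two high bits in `qh`, then shifted down by 32 to make them signed. A scalar sketch for one 256-value block, matching the kernel's `s[0]/s[2]/s[4]/s[6]` indexing (the struct is an illustrative stand-in; the real one stores `d` as fp16):

```cpp
#include <cstdint>

constexpr int QK_K = 256;

struct block_q6_K {              // illustrative stand-in
    uint8_t ql[QK_K / 2];        // low 4 bits of each weight
    uint8_t qh[QK_K / 4];        // high 2 bits, four weights per byte
    int8_t  scales[QK_K / 16];   // signed scale per 16 weights
    float   d;                   // super-block scale
};

void dequantize_block_q6_K(const block_q6_K & x, float * y) {
    const uint8_t * ql = x.ql;
    const uint8_t * qh = x.qh;
    const int8_t  * sc = x.scales;
    for (int n = 0; n < QK_K; n += 128, y += 128, ql += 64, qh += 32, sc += 8) {
        for (int l = 0; l < 32; ++l) {
            const int is = l / 16;  // which 16-weight scale applies
            const int q1 = (int)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
            const int q2 = (int)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
            const int q3 = (int)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            const int q4 = (int)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            y[l +  0] = x.d * sc[is + 0] * q1;
            y[l + 32] = x.d * sc[is + 2] * q2;
            y[l + 64] = x.d * sc[is + 4] * q3;
            y[l + 96] = x.d * sc[is + 6] * q4;
        }
    }
}
```
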
| 866 | |||
| 867 | static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, | ||
| 868 | float *dst, const int ncols, | ||
| 869 | const int nrows, | ||
| 870 | dpct::queue_ptr stream) { | ||
| 871 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 872 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 873 | // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead | ||
| 874 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 875 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 876 | { | ||
| 877 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 878 | {sycl::aspect::fp16}); | ||
| 879 | |||
| 880 | stream->parallel_for( | ||
| 881 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 882 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 883 | dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>( | ||
| 884 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 885 | }); | ||
| 886 | } | ||
| 887 | } | ||
| 888 | |||
| 889 | |||
| 890 | static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, | ||
| 891 | float *dst, const int ncols, | ||
| 892 | const int nrows, | ||
| 893 | dpct::queue_ptr stream) { | ||
| 894 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 895 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 896 | // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead | ||
| 897 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 898 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 899 | { | ||
| 900 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 901 | {sycl::aspect::fp16}); | ||
| 902 | |||
| 903 | stream->parallel_for( | ||
| 904 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 905 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 906 | dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>( | ||
| 907 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 908 | }); | ||
| 909 | } | ||
| 910 | } | ||
| 911 | |||
| 912 | static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y, | ||
| 913 | float *dst, const int ncols, | ||
| 914 | const int nrows, | ||
| 915 | dpct::queue_ptr stream) { | ||
| 916 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 917 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 918 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 919 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 920 | { | ||
| 921 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 922 | {sycl::aspect::fp16}); | ||
| 923 | |||
| 924 | stream->parallel_for( | ||
| 925 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 926 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 927 | dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>( | ||
| 928 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 929 | }); | ||
| 930 | } | ||
| 931 | } | ||
| 932 | |||
| 933 | static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y, | ||
| 934 | float *dst, const int ncols, | ||
| 935 | const int nrows, | ||
| 936 | dpct::queue_ptr stream) { | ||
| 937 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 938 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 939 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 940 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 941 | { | ||
| 942 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 943 | {sycl::aspect::fp16}); | ||
| 944 | |||
| 945 | stream->parallel_for( | ||
| 946 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 947 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 948 | dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>( | ||
| 949 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 950 | }); | ||
| 951 | } | ||
| 952 | } | ||
| 953 | |||
| 954 | static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, | ||
| 955 | float *dst, const int ncols, | ||
| 956 | const int nrows, | ||
| 957 | dpct::queue_ptr stream) { | ||
| 958 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 959 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 960 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 961 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 962 | { | ||
| 963 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 964 | {sycl::aspect::fp16}); | ||
| 965 | |||
| 966 | stream->parallel_for( | ||
| 967 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 968 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 969 | dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>( | ||
| 970 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 971 | }); | ||
| 972 | } | ||
| 973 | } | ||
| 974 | |||
| 975 | static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, | ||
| 976 | float *dst, const int ncols, | ||
| 977 | const int nrows, | ||
| 978 | dpct::queue_ptr stream) { | ||
| 979 | GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); | ||
| 980 | const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; | ||
| 981 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 982 | const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); | ||
| 983 | { | ||
| 984 | dpct::has_capability_or_fail(stream->get_device(), | ||
| 985 | {sycl::aspect::fp16}); | ||
| 986 | |||
| 987 | stream->parallel_for( | ||
| 988 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 989 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||
| 990 | dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>( | ||
| 991 | vx, y, dst, ncols, nrows, item_ct1); | ||
| 992 | }); | ||
| 993 | } | ||
| 994 | } | ||
| 995 | |||
| 996 | static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y, | ||
| 997 | float *dst, const int ncols, | ||
| 998 | const int nrows, | ||
| 999 | dpct::queue_ptr stream) { | ||
| 1000 | GGML_ASSERT(ncols % QK_K == 0); | ||
| 1001 | const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 | ||
| 1002 | const int block_num_y = (nrows + ny - 1) / ny; | ||
| 1003 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 1004 | const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); | ||
| 1005 | stream->parallel_for( | ||
| 1006 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 1007 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { | ||
| 1008 | dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); | ||
| 1009 | }); | ||
| 1010 | } | ||
| 1011 | |||
| 1012 | static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, | ||
| 1013 | float *dst, const int ncols, | ||
| 1014 | const int nrows, | ||
| 1015 | dpct::queue_ptr stream) { | ||
| 1016 | GGML_ASSERT(ncols % QK_K == 0); | ||
| 1017 | const int ny = 2 / K_QUANTS_PER_ITERATION; | ||
| 1018 | const int block_num_y = (nrows + ny - 1) / ny; | ||
| 1019 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 1020 | const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); | ||
| 1021 | stream->parallel_for( | ||
| 1022 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 1023 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { | ||
| 1024 | dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); | ||
| 1025 | }); | ||
| 1026 | } | ||
| 1027 | |||
| 1028 | static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, | ||
| 1029 | float *dst, const int ncols, | ||
| 1030 | const int nrows, | ||
| 1031 | dpct::queue_ptr stream) { | ||
| 1032 | GGML_ASSERT(ncols % QK_K == 0); | ||
| 1033 | const int ny = 2 / K_QUANTS_PER_ITERATION; | ||
| 1034 | const int block_num_y = (nrows + ny - 1) / ny; | ||
| 1035 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 1036 | const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); | ||
| 1037 | stream->parallel_for( | ||
| 1038 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 1039 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { | ||
| 1040 | dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); | ||
| 1041 | }); | ||
| 1042 | } | ||
| 1043 | |||
| 1044 | static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y, | ||
| 1045 | float *dst, const int ncols, | ||
| 1046 | const int nrows, | ||
| 1047 | dpct::queue_ptr stream) { | ||
| 1048 | GGML_ASSERT(ncols % QK_K == 0); | ||
| 1049 | const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); | ||
| 1050 | stream->parallel_for( | ||
| 1051 | sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), | ||
| 1052 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { | ||
| 1053 | dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); | ||
| 1054 | }); | ||
| 1055 | } | ||
| 1056 | |||
| 1057 | static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, | ||
| 1058 | float *dst, const int ncols, | ||
| 1059 | const int nrows, | ||
| 1060 | dpct::queue_ptr stream) { | ||
| 1061 | GGML_ASSERT(ncols % QK_K == 0); | ||
| 1062 | const int ny = 2 / K_QUANTS_PER_ITERATION; | ||
| 1063 | const int block_num_y = (nrows + ny - 1) / ny; | ||
| 1064 | const sycl::range<3> block_nums(1, 1, block_num_y); | ||
| 1065 | const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); | ||
| 1066 | stream->parallel_for( | ||
| 1067 | sycl::nd_range<3>(block_nums * block_dims, block_dims), | ||
| 1068 | [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { | ||
| 1069 | dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); | ||
| 1070 | }); | ||
| 1071 | } | ||
| 1072 | |||
| 1073 | void ggml_sycl_op_dequantize_mul_mat_vec( | ||
| 1074 | ggml_backend_sycl_context & ctx, | ||
| 1075 | const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, | ||
| 1076 | const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, | ||
| 1077 | float *dst_dd_i, const int64_t row_low, const int64_t row_high, | ||
| 1078 | const int64_t src1_ncols, const int64_t src1_padded_row_size, | ||
| 1079 | const dpct::queue_ptr &stream) { | ||
| 1080 | |||
| 1081 | const int64_t ne00 = src0->ne[0]; | ||
| 1082 | const int64_t row_diff = row_high - row_low; | ||
| 1083 | GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||
| 1084 | // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics | ||
| 1085 | #ifdef GGML_SYCL_F16 | ||
| 1086 | ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool()); | ||
| 1087 | sycl::half *src1_dfloat = nullptr; // dfloat == half | ||
| 1088 | |||
| 1089 | bool src1_convert_f16 = | ||
| 1090 | src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || | ||
| 1091 | src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || | ||
| 1092 | src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; | ||
| 1093 | |||
| 1094 | if (src1_convert_f16) { | ||
| 1095 | scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, | ||
| 1096 | " : converting src1 to fp16"); | ||
| 1097 | src1_dfloat = src1_dfloat_a.alloc(ne00); | ||
| 1098 | const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); | ||
| 1099 | GGML_ASSERT(to_fp16_sycl != nullptr); | ||
| 1100 | to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream); | ||
| 1101 | } | ||
| 1102 | #else | ||
| 1103 | const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion | ||
| 1104 | #endif // GGML_SYCL_F16 | ||
| 1105 | |||
| 1106 | switch (src0->type) { | ||
| 1107 | case GGML_TYPE_Q4_0: | ||
| 1108 | if ((ggml_tensor_extra_gpu*)dst->src[0]->extra && | ||
| 1109 | ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { | ||
| 1110 | dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1111 | } else { | ||
| 1112 | dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1113 | } | ||
| 1114 | break; | ||
| 1115 | case GGML_TYPE_Q4_1: | ||
| 1116 | dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1117 | break; | ||
| 1118 | case GGML_TYPE_Q5_0: | ||
| 1119 | dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1120 | break; | ||
| 1121 | case GGML_TYPE_Q5_1: | ||
| 1122 | dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1123 | break; | ||
| 1124 | case GGML_TYPE_Q8_0: | ||
| 1125 | dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1126 | break; | ||
| 1127 | case GGML_TYPE_Q2_K: | ||
| 1128 | dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); | ||
| 1129 | break; | ||
| 1130 | case GGML_TYPE_Q3_K: | ||
| 1131 | dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); | ||
| 1132 | break; | ||
| 1133 | case GGML_TYPE_Q4_K: | ||
| 1134 | if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && | ||
| 1135 | ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { | ||
| 1136 | // reorder is currently not supported for dmmv | ||
| 1137 | GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); | ||
| 1138 | } else { | ||
| 1139 | dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); | ||
| 1140 | } | ||
| 1141 | break; | ||
| 1142 | case GGML_TYPE_Q5_K: | ||
| 1143 | dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); | ||
| 1144 | break; | ||
| 1145 | case GGML_TYPE_Q6_K: | ||
| 1146 | dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); | ||
| 1147 | break; | ||
| 1148 | case GGML_TYPE_F16: | ||
| 1149 | convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); | ||
| 1150 | break; | ||
| 1151 | default: | ||
| 1152 | printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); | ||
| 1153 | GGML_ABORT("fatal error"); | ||
| 1154 | } | ||
| 1155 | |||
| 1156 | GGML_UNUSED(src1); | ||
| 1157 | GGML_UNUSED(dst); | ||
| 1158 | GGML_UNUSED(src1_ddq_i); | ||
| 1159 | GGML_UNUSED(src1_ncols); | ||
| 1160 | GGML_UNUSED(src1_padded_row_size); | ||
| 1161 | GGML_UNUSED(ctx); | ||
| 1162 | } | ||
