1#include "cpy.hpp"
  2
  3#include <float.h>
  4
  5#include "dequantize.hpp"
  6#include "ggml-sycl/common.hpp"
  7#include "ggml-sycl/presets.hpp"
  8#include "ggml.h"
  9
 10
 11static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
 12    const float * xi   = (const float *) cxi;
 13    float *       dsti = (float *) cdsti;
 14
 15    *dsti = *xi;
 16}
 17
 18static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
 19    const float * xi   = (const float *) cxi;
 20    sycl::half *  dsti = (sycl::half *) cdsti;
 21
 22    *dsti = sycl::vec<float, 1>(*xi).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
 23}
 24
 25static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
 26    const sycl::half * xi   = (const sycl::half *) cxi;
 27    sycl::half *       dsti = (sycl::half *) cdsti;
 28
 29    *dsti = *xi;
 30}
 31
 32static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
 33    const sycl::half * xi   = (const sycl::half *) cxi;
 34    float *            dsti = (float *) cdsti;
 35
 36    *dsti = *xi;
 37}
 38
 39static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
 40    const int16_t * xi   = (const int16_t *) cxi;
 41    int16_t *       dsti = (int16_t *) cdsti;
 42
 43    *dsti = *xi;
 44}
 45
 46static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
 47    const int32_t * xi   = (const int32_t *) cxi;
 48    int32_t *       dsti = (int32_t *) cdsti;
 49
 50    *dsti = *xi;
 51}
 52
 53template <cpy_kernel_t cpy_1>
 54static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
 55                        const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
 56                        const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
 57                        const sycl::nd_item<3> & item_ct1) {
 58    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 59
 60    if (i >= ne) {
 61        return;
 62    }
 63
 64    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
 65    // then combine those indices with the corresponding byte offsets to get the total offsets
 66    const int i03      = i / (ne00 * ne01 * ne02);
 67    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
 68    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
 69    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
 70    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
 71
 72    const int i13        = i / (ne10 * ne11 * ne12);
 73    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
 74    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
 75    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
 76    const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
 77
 78    cpy_1(cx + x_offset, cdst + dst_offset);
 79}
 80
 81
 82/* quantized type same copy */
 83template<typename T>
 84static void cpy_blck_q_q(const char * cxi, char * cdsti) {
 85    const T * xi = (const T *) cxi;
 86    T * dsti = (T *) cdsti;
 87    *dsti = *xi;
 88}
 89
 90
 91static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
 92    float * cdstf = (float *) (cdsti);
 93
 94    for (int j = 0; j < QK8_0; j += 2) {
 95        dfloat2 dq;
 96        dequantize_q8_0(cxi, 0, j, dq);
 97        *(cdstf + j)     = dq.x();
 98        *(cdstf + j + 1) = dq.y();
 99    }
100}
101
102
103
104template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
105    float * cdstf = (float *) (cdsti);
106
107    for (int j = 0; j < qk / 2; j++) {
108        dfloat2 dq;
109        dequant(cxi, 0, j, dq);
110        *(cdstf + j)          = dq.x();
111        *(cdstf + j + qk / 2) = dq.y();
112    }
113}
114
115
116template <typename T, int qk>
117static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
118                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
119                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
120                      const sycl::nd_item<3> & item_ct1) {
121    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
122
123    if (i >= ne) {
124        return;
125    }
126
127    const int i03      = i / (ne00 * ne01 * ne02);
128    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
129    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
130    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
131    const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
132
133
134    const int i13        = i / (ne10 * ne11 * ne12);
135    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
136    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
137    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
138    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
139
140    cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
141}
142
143template <cpy_kernel_t cpy_blck, int qk>
144static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
145                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
146                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
147                      const sycl::nd_item<3> & item_ct1) {
148    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
149
150    if (i >= ne) {
151        return;
152    }
153
154
155    const int i03      = i / (ne00 * ne01 * ne02);
156    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
157    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
158    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
159    const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
160
161    const int i13        = i / (ne10 * ne11 * ne12);
162    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
163    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
164    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
165    const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
166
167    cpy_blck(cx + x_offset, cdst + dst_offset);
168}
169
170template <cpy_kernel_t cpy_blck, int qk>
171static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
172                      const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
173                      const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
174                      const sycl::nd_item<3> & item_ct1) {
175    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
176
177    if (i >= ne) {
178        return;
179    }
180
181    const int i03      = i / (ne00 * ne01 * ne02);
182    const int i02      = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
183    const int i01      = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
184    const int i00      = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
185    const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
186
187    const int i13        = i / (ne10 * ne11 * ne12);
188    const int i12        = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
189    const int i11        = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
190    const int i10        = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
191    const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
192
193    cpy_blck(cx + x_offset, cdst + dst_offset);
194}
195
196static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
197                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
198                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
199                                  const int nb12, const int nb13, queue_ptr stream) {
200    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
201    {
202        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
203
204        stream->parallel_for(
205            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
206                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
207            [=](sycl::nd_item<3> item_ct1) {
208                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
209                                           nb10, nb11, nb12, nb13, item_ct1);
210            });
211    }
212}
213
214static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
215                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
216                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
217                                  const int nb12, const int nb13, queue_ptr stream) {
218    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
219    {
220        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
221
222        stream->parallel_for(
223            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
224                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
225            [=](sycl::nd_item<3> item_ct1) {
226                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
227                                           nb10, nb11, nb12, nb13, item_ct1);
228            });
229    }
230}
231
232static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
233                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
234                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
235                                  const int nb12, const int nb13, queue_ptr stream) {
236    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
237    {
238        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
239
240        stream->parallel_for(
241            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
242                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
243            [=](sycl::nd_item<3> item_ct1) {
244                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
245                                           nb10, nb11, nb12, nb13, item_ct1);
246            });
247    }
248}
249
250static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
251                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
252                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
253                                   const int nb12, const int nb13, queue_ptr stream) {
254    GGML_ASSERT(ne % QK8_0 == 0);
255    const int num_blocks = ne / QK8_0;
256    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
257                         [=](sycl::nd_item<3> item_ct1) {
258                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
259                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
260                         });
261}
262
263static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
264                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
265                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
266                                   const int nb12, const int nb13, queue_ptr stream) {
267    const int num_blocks = ne;
268    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
269                         [=](sycl::nd_item<3> item_ct1) {
270                             cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
271                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
272                         });
273}
274
275static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
276                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
277                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
278                                   const int nb12, const int nb13, queue_ptr stream) {
279    GGML_ASSERT(ne % QK4_0 == 0);
280    const int num_blocks = ne / QK4_0;
281    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
282                         [=](sycl::nd_item<3> item_ct1) {
283                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
284                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
285                         });
286}
287
288static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
289                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
290                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
291                                   const int nb12, const int nb13, queue_ptr stream) {
292    const int num_blocks = ne;
293    stream->parallel_for(
294        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
295            cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
296                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
297                                                                     item_ct1);
298        });
299}
300
301static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
302                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
303                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
304                                   const int nb12, const int nb13, queue_ptr stream) {
305    GGML_ASSERT(ne % QK4_1 == 0);
306    const int num_blocks = ne / QK4_1;
307    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
308                         [=](sycl::nd_item<3> item_ct1) {
309                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
310                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
311                         });
312}
313
314static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
315                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
316                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
317                                   const int nb12, const int nb13, queue_ptr stream) {
318    const int num_blocks = ne;
319    stream->parallel_for(
320        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
321            cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
322                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
323                                                                     item_ct1);
324        });
325}
326
327static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
328                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
329                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
330                                   const int nb12, const int nb13, queue_ptr stream) {
331    GGML_ASSERT(ne % QK5_0 == 0);
332    const int num_blocks = ne / QK5_0;
333    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
334                         [=](sycl::nd_item<3> item_ct1) {
335                             cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
336                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
337                         });
338}
339
340static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
341                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
342                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
343                                   const int nb12, const int nb13, queue_ptr stream) {
344    const int num_blocks = ne;
345    stream->parallel_for(
346        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
347            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
348                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
349                                                                     item_ct1);
350        });
351}
352
353static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
354                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
355                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
356                                   const int nb12, const int nb13, queue_ptr stream) {
357    GGML_ASSERT(ne % QK5_1 == 0);
358    const int num_blocks = ne / QK5_1;
359    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
360                         [=](sycl::nd_item<3> item_ct1) {
361                             cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
362                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
363                         });
364}
365
366static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
367                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
368                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
369                                   const int nb12, const int nb13, queue_ptr stream) {
370    const int num_blocks = ne;
371    stream->parallel_for(
372        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
373            cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
374                                                                     nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
375                                                                     item_ct1);
376        });
377}
378
379static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
380                                     const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
381                                     const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
382                                     const int nb12, const int nb13, queue_ptr stream) {
383    GGML_ASSERT(ne % QK4_NL == 0);
384    const int num_blocks = ne / QK4_NL;
385    stream->parallel_for(
386        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
387            cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
388                                                   ne12, nb10, nb11, nb12, nb13, item_ct1);
389        });
390}
391
392static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
393                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
394                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
395                                  const int nb12, const int nb13, queue_ptr stream) {
396    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
397    {
398        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
399
400        stream->parallel_for(
401            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
402                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
403            [=](sycl::nd_item<3> item_ct1) {
404                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
405                                           nb10, nb11, nb12, nb13, item_ct1);
406            });
407    }
408}
409
410static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
411                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
412                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
413                                  const int nb12, const int nb13, queue_ptr stream) {
414    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
415    {
416        // dpct::has_capability_or_fail(stream->get_device(),
417        //                              {sycl::aspect::fp16});
418
419        stream->parallel_for(
420            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
421                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
422            [=](sycl::nd_item<3> item_ct1) {
423                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
424                                           nb10, nb11, nb12, nb13, item_ct1);
425            });
426    }
427}
428
429static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
430                                  const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
431                                  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
432                                  const int nb12, const int nb13, queue_ptr stream) {
433    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
434    {
435        // dpct::has_capability_or_fail(stream->get_device(),
436        //                              {sycl::aspect::fp16});
437
438        stream->parallel_for(
439            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
440                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
441            [=](sycl::nd_item<3> item_ct1) {
442                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
443                                           nb10, nb11, nb12, nb13, item_ct1);
444            });
445    }
446}
447
448static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
449                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
450                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
451                                   const int nb12, const int nb13, queue_ptr stream) {
452    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
453    stream->parallel_for(
454        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
455                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
456            cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
457        });
458}
459
460
461static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
462                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
463                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
464                                   const int nb12, const int nb13, queue_ptr stream) {
465    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
466    stream->parallel_for(
467        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
468                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
469            cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
470        });
471}
472
473
474static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
475                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
476                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
477                                   const int nb12, const int nb13, queue_ptr stream) {
478    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
479
480    stream->parallel_for(
481        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
482                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
483            cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
484        });
485}
486
487
488static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
489                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
490                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
491                                   const int nb12, const int nb13, queue_ptr stream) {
492    const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
493    stream->parallel_for(
494        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
495            cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
496        });
497}
498
499
500static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
501                                   const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
502                                   const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
503                                   const int nb12, const int nb13, queue_ptr stream) {
504
505   const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
506   stream->parallel_for(
507        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
508            cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
509        });
510}
511
512void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
513    // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
514    scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0));
515    const int64_t ne = ggml_nelements(src0);
516    GGML_ASSERT(ne == ggml_nelements(src1));
517
518    GGML_TENSOR_BINARY_OP_LOCALS01;
519
520    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
521    queue_ptr main_stream = ctx.stream();
522
523    char * src0_ddc = (char *) src0->data;
524    char * src1_ddc = (char *) src1->data;
525    if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
526        GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
527        main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
528    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
529        ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
530                              nb11, nb12, nb13, main_stream);
531    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
532        ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
533                              nb11, nb12, nb13, main_stream);
534    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
535        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
536                               nb11, nb12, nb13, main_stream);
537    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
538        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
539                               nb11, nb12, nb13, main_stream);
540    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
541        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
542                               nb11, nb12, nb13, main_stream);
543    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
544        ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
545                              nb11, nb12, nb13, main_stream);
546    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
547        ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
548                              nb11, nb12, nb13, main_stream);
549    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
550        ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
551                              nb11, nb12, nb13, main_stream);
552    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
553        ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
554                              nb11, nb12, nb13, main_stream);
555    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
556        ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
557                               nb11, nb12, nb13, main_stream);
558    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
559        ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
560                               nb11, nb12, nb13, main_stream);
561    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
562        ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
563                               nb11, nb12, nb13, main_stream);
564    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
565        ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
566                               nb11, nb12, nb13, main_stream);
567    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
568        ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
569                               nb11, nb12, nb13, main_stream);
570    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
571        ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
572                               nb11, nb12, nb13, main_stream);
573    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
574        ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
575                               nb11, nb12, nb13, main_stream);
576    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
577        ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
578                                 nb10, nb11, nb12, nb13, main_stream);
579    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
580        ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
581    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
582        ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
583    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
584        ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
585    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
586        ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
587    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
588        ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
589    } else {
590        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
591                       ggml_type_name(src1->type));
592        GGML_ABORT("fatal error");
593    }
594} catch (const sycl::exception & exc) {
595    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
596    std::exit(1);
597}
598
599void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
600    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
601    ggml_sycl_cpy(ctx, dst->src[0], dst);
602}