1#include "cpy.hpp"
2
3#include <float.h>
4
5#include "dequantize.hpp"
6#include "ggml-sycl/common.hpp"
7#include "ggml-sycl/presets.hpp"
8#include "ggml.h"
9
10
static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
    // Copy a single fp32 element from source to destination.
    const float * src = reinterpret_cast<const float *>(cxi);
    float *       dst = reinterpret_cast<float *>(cdsti);

    *dst = *src;
}
17
static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
    // Convert a single fp32 element to fp16 (sycl::half) and store it.
    const float * xi = (const float *) cxi;
    sycl::half * dsti = (sycl::half *) cdsti;

    // rounding_mode::automatic leaves the float->half rounding choice to the implementation
    *dsti = sycl::vec<float, 1>(*xi).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
24
25static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
26 const sycl::half * xi = (const sycl::half *) cxi;
27 sycl::half * dsti = (sycl::half *) cdsti;
28
29 *dsti = *xi;
30}
31
32static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
33 const sycl::half * xi = (const sycl::half *) cxi;
34 float * dsti = (float *) cdsti;
35
36 *dsti = *xi;
37}
38
static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
    // Copy a single int16 element from source to destination.
    const int16_t * src = reinterpret_cast<const int16_t *>(cxi);
    int16_t *       dst = reinterpret_cast<int16_t *>(cdsti);

    *dst = *src;
}
45
static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
    // Copy a single int32 element from source to destination.
    const int32_t * src = reinterpret_cast<const int32_t *>(cxi);
    int32_t *       dst = reinterpret_cast<int32_t *>(cdsti);

    *dst = *src;
}
52
53template <cpy_kernel_t cpy_1>
54static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
55 const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
56 const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
57 const sycl::nd_item<3> & item_ct1) {
58 const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
59
60 if (i >= ne) {
61 return;
62 }
63
64 // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
65 // then combine those indices with the corresponding byte offsets to get the total offsets
66 const int i03 = i / (ne00 * ne01 * ne02);
67 const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
68 const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
69 const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
70 const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
71
72 const int i13 = i / (ne10 * ne11 * ne12);
73 const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
74 const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
75 const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
76 const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
77
78 cpy_1(cx + x_offset, cdst + dst_offset);
79}
80
81
/* quantized type same copy */
template <typename T> static void cpy_blck_q_q(const char * cxi, char * cdsti) {
    // Copy one whole quantized block of type T by value (same source and
    // destination quantization, so no re-quantization is needed).
    *reinterpret_cast<T *>(cdsti) = *reinterpret_cast<const T *>(cxi);
}
89
90
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    // Dequantize one q8_0 block (QK8_0 elements) into QK8_0 consecutive floats.
    float * cdstf = (float *) (cdsti);

    // each dequantize_q8_0 call produces a pair of values stored at
    // consecutive positions j and j+1
    for (int j = 0; j < QK8_0; j += 2) {
        dfloat2 dq;
        dequantize_q8_0(cxi, 0, j, dq);
        *(cdstf + j) = dq.x();
        *(cdstf + j + 1) = dq.y();
    }
}
101
102
103
// Dequantize one quantized block (qk elements) into qk consecutive floats
// using the given dequantization kernel.
template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *) (cdsti);

    // each dequant call yields a pair of values that land qk/2 apart: one at
    // position j and one at j + qk/2 (presumably the low/high nibble halves
    // of the block — confirm against the dequantize_* implementations)
    for (int j = 0; j < qk / 2; j++) {
        dfloat2 dq;
        dequant(cxi, 0, j, dq);
        *(cdstf + j) = dq.x();
        *(cdstf + j + qk / 2) = dq.y();
    }
}
114
115
116template <typename T, int qk>
117static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
118 const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
119 const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
120 const sycl::nd_item<3> & item_ct1) {
121 const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
122
123 if (i >= ne) {
124 return;
125 }
126
127 const int i03 = i / (ne00 * ne01 * ne02);
128 const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
129 const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
130 const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
131 const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
132
133
134 const int i13 = i / (ne10 * ne11 * ne12);
135 const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
136 const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
137 const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
138 const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
139
140 cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
141}
142
143template <cpy_kernel_t cpy_blck, int qk>
144static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
145 const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
146 const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
147 const sycl::nd_item<3> & item_ct1) {
148 const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
149
150 if (i >= ne) {
151 return;
152 }
153
154
155 const int i03 = i / (ne00 * ne01 * ne02);
156 const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
157 const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
158 const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
159 const int x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
160
161 const int i13 = i / (ne10 * ne11 * ne12);
162 const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
163 const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
164 const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
165 const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
166
167 cpy_blck(cx + x_offset, cdst + dst_offset);
168}
169
170template <cpy_kernel_t cpy_blck, int qk>
171static void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
172 const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
173 const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
174 const sycl::nd_item<3> & item_ct1) {
175 const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
176
177 if (i >= ne) {
178 return;
179 }
180
181 const int i03 = i / (ne00 * ne01 * ne02);
182 const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
183 const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
184 const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
185 const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
186
187 const int i13 = i / (ne10 * ne11 * ne12);
188 const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
189 const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
190 const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
191 const int dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
192
193 cpy_blck(cx + x_offset, cdst + dst_offset);
194}
195
196static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
197 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
198 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
199 const int nb12, const int nb13, queue_ptr stream) {
200 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
201 {
202 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
203
204 stream->parallel_for(
205 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
206 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
207 [=](sycl::nd_item<3> item_ct1) {
208 cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
209 nb10, nb11, nb12, nb13, item_ct1);
210 });
211 }
212}
213
214static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
215 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
216 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
217 const int nb12, const int nb13, queue_ptr stream) {
218 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
219 {
220 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
221
222 stream->parallel_for(
223 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
224 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
225 [=](sycl::nd_item<3> item_ct1) {
226 cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
227 nb10, nb11, nb12, nb13, item_ct1);
228 });
229 }
230}
231
232static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
233 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
234 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
235 const int nb12, const int nb13, queue_ptr stream) {
236 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
237 {
238 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
239
240 stream->parallel_for(
241 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
242 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
243 [=](sycl::nd_item<3> item_ct1) {
244 cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
245 nb10, nb11, nb12, nb13, item_ct1);
246 });
247 }
248}
249
250static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
251 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
252 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
253 const int nb12, const int nb13, queue_ptr stream) {
254 GGML_ASSERT(ne % QK8_0 == 0);
255 const int num_blocks = ne / QK8_0;
256 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
257 [=](sycl::nd_item<3> item_ct1) {
258 cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
259 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
260 });
261}
262
263static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
264 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
265 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
266 const int nb12, const int nb13, queue_ptr stream) {
267 const int num_blocks = ne;
268 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
269 [=](sycl::nd_item<3> item_ct1) {
270 cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
271 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
272 });
273}
274
275static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
276 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
277 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
278 const int nb12, const int nb13, queue_ptr stream) {
279 GGML_ASSERT(ne % QK4_0 == 0);
280 const int num_blocks = ne / QK4_0;
281 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
282 [=](sycl::nd_item<3> item_ct1) {
283 cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
284 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
285 });
286}
287
288static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
289 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
290 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
291 const int nb12, const int nb13, queue_ptr stream) {
292 const int num_blocks = ne;
293 stream->parallel_for(
294 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
295 cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
296 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
297 item_ct1);
298 });
299}
300
301static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
302 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
303 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
304 const int nb12, const int nb13, queue_ptr stream) {
305 GGML_ASSERT(ne % QK4_1 == 0);
306 const int num_blocks = ne / QK4_1;
307 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
308 [=](sycl::nd_item<3> item_ct1) {
309 cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
310 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
311 });
312}
313
314static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
315 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
316 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
317 const int nb12, const int nb13, queue_ptr stream) {
318 const int num_blocks = ne;
319 stream->parallel_for(
320 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
321 cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
322 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
323 item_ct1);
324 });
325}
326
327static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
328 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
329 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
330 const int nb12, const int nb13, queue_ptr stream) {
331 GGML_ASSERT(ne % QK5_0 == 0);
332 const int num_blocks = ne / QK5_0;
333 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
334 [=](sycl::nd_item<3> item_ct1) {
335 cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
336 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
337 });
338}
339
340static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
341 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
342 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
343 const int nb12, const int nb13, queue_ptr stream) {
344 const int num_blocks = ne;
345 stream->parallel_for(
346 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
347 cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
348 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
349 item_ct1);
350 });
351}
352
353static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
354 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
355 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
356 const int nb12, const int nb13, queue_ptr stream) {
357 GGML_ASSERT(ne % QK5_1 == 0);
358 const int num_blocks = ne / QK5_1;
359 stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
360 [=](sycl::nd_item<3> item_ct1) {
361 cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
362 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
363 });
364}
365
366static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
367 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
368 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
369 const int nb12, const int nb13, queue_ptr stream) {
370 const int num_blocks = ne;
371 stream->parallel_for(
372 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
373 cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
374 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
375 item_ct1);
376 });
377}
378
379static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
380 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
381 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
382 const int nb12, const int nb13, queue_ptr stream) {
383 GGML_ASSERT(ne % QK4_NL == 0);
384 const int num_blocks = ne / QK4_NL;
385 stream->parallel_for(
386 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
387 cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
388 ne12, nb10, nb11, nb12, nb13, item_ct1);
389 });
390}
391
392static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
393 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
394 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
395 const int nb12, const int nb13, queue_ptr stream) {
396 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
397 {
398 dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
399
400 stream->parallel_for(
401 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
402 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
403 [=](sycl::nd_item<3> item_ct1) {
404 cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
405 nb10, nb11, nb12, nb13, item_ct1);
406 });
407 }
408}
409
410static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
411 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
412 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
413 const int nb12, const int nb13, queue_ptr stream) {
414 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
415 {
416 // dpct::has_capability_or_fail(stream->get_device(),
417 // {sycl::aspect::fp16});
418
419 stream->parallel_for(
420 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
421 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
422 [=](sycl::nd_item<3> item_ct1) {
423 cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
424 nb10, nb11, nb12, nb13, item_ct1);
425 });
426 }
427}
428
429static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
430 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
431 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
432 const int nb12, const int nb13, queue_ptr stream) {
433 const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
434 {
435 // dpct::has_capability_or_fail(stream->get_device(),
436 // {sycl::aspect::fp16});
437
438 stream->parallel_for(
439 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
440 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
441 [=](sycl::nd_item<3> item_ct1) {
442 cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
443 nb10, nb11, nb12, nb13, item_ct1);
444 });
445 }
446}
447
448static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
449 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
450 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
451 const int nb12, const int nb13, queue_ptr stream) {
452 const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
453 stream->parallel_for(
454 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
455 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
456 cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
457 });
458}
459
460
461static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
462 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
463 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
464 const int nb12, const int nb13, queue_ptr stream) {
465 const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
466 stream->parallel_for(
467 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
468 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
469 cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
470 });
471}
472
473
474static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
475 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
476 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
477 const int nb12, const int nb13, queue_ptr stream) {
478 const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
479
480 stream->parallel_for(
481 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
482 sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
483 cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
484 });
485}
486
487
488static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
489 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
490 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
491 const int nb12, const int nb13, queue_ptr stream) {
492 const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
493 stream->parallel_for(
494 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
495 cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
496 });
497}
498
499
500static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
501 const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
502 const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
503 const int nb12, const int nb13, queue_ptr stream) {
504
505 const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
506 stream->parallel_for(
507 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
508 cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
509 });
510}
511
// Copy/convert src0 into src1, dispatching on the (src0, src1) type pair.
// Contiguous same-type copies take a raw device memcpy fast path; every other
// supported pair is handled by a dedicated kernel launcher. Aborts on an
// unsupported type combination.
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
    // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
    scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0));
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    GGML_TENSOR_BINARY_OP_LOCALS01;

    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    queue_ptr main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;
    // fast path: same type and both contiguous -> plain device-to-device memcpy
    if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
        GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
        main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f32_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_f16_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        ggml_cpy_f16_f16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
        ggml_cpy_i16_i16_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
        ggml_cpy_i32_i32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                              nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q8_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_f32_q5_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_0_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
        ggml_cpy_f32_q5_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
                               nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
        ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
                                 nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
        ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else {
        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
                       ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
598
// GGML_OP_DUP: duplicate dst->src[0] into dst, implemented as a copy.
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_cpy(ctx, dst->src[0], dst);
}