aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-sycl/dpct
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/dpct
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/dpct')
-rw-r--r--llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp3002
1 files changed, 3002 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp b/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp
new file mode 100644
index 0000000..ece66a7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -0,0 +1,3002 @@
1//
2// MIT license
3// Copyright (C) 2024 Intel Corporation
4// SPDX-License-Identifier: MIT
5//
6
7//
8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9// See https://llvm.org/LICENSE.txt for license information.
10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11//
12
13#ifndef GGML_SYCL_DPCT_HELPER_HPP
14#define GGML_SYCL_DPCT_HELPER_HPP
15
16#include <sycl/sycl.hpp>
17#include <sycl/half_type.hpp>
18#include <oneapi/mkl.hpp>
19
20#include <map>
21
22#include "ggml.h"
23
24#if defined(__linux__)
25#include <sys/mman.h>
26#elif defined(_WIN64)
27#ifndef NOMINMAX
28#define NOMINMAX
29#endif
30#include <windows.h>
31#else
32#error "Only support Windows and Linux."
33#endif
34
35#if defined(__linux__)
36#include <unistd.h>
37#include <sys/syscall.h>
38#endif
39#if defined(_WIN64)
40#ifndef NOMINMAX
41#define NOMINMAX
42#endif
43#include <windows.h>
44#endif
45
46#define DPCT_COMPATIBILITY_TEMP (900)
47
48#if defined(_MSC_VER)
49#define __dpct_align__(n) __declspec(align(n))
50#define __dpct_inline__ __forceinline
51#else
52#define __dpct_align__(n) __attribute__((aligned(n)))
53#define __dpct_inline__ __inline__ __attribute__((always_inline))
54#endif
55
56#if defined(_MSC_VER)
57#define __dpct_noinline__ __declspec(noinline)
58#else
59#define __dpct_noinline__ __attribute__((noinline))
60#endif
61
62inline std::string get_device_type_name(const sycl::device &Device) {
63 auto DeviceType = Device.get_info<sycl::info::device::device_type>();
64 switch (DeviceType) {
65 case sycl::info::device_type::cpu:
66 return "cpu";
67 case sycl::info::device_type::gpu:
68 return "gpu";
69 case sycl::info::device_type::host:
70 return "host";
71 case sycl::info::device_type::accelerator:
72 return "acc";
73 default:
74 return "unknown";
75 }
76}
77
78inline std::string get_device_backend_and_type(const sycl::device &device) {
79 std::stringstream device_type;
80 sycl::backend backend = device.get_backend();
81 device_type << backend << ":" << get_device_type_name(device);
82 return device_type.str();
83}
84
/// Grouped-batch matrix descriptor handed to oneMKL gemm_batch-style calls;
/// \p Ts is the scalar coefficient type.
template <typename Ts> struct matrix_info_t {
    oneapi::mkl::transpose transpose_info[2]; // op(A), op(B)
    Ts value_info[2];                         // scaling coefficients (presumably alpha, beta — confirm at call site)
    std::int64_t size_info[3];                // matrix dimensions (presumably m, n, k — confirm at call site)
    std::int64_t ld_info[3];                  // leading dimensions of the three matrices
    std::int64_t groupsize_info;              // number of matrices in the group
};
92
93namespace dpct
94{
95 typedef sycl::queue *queue_ptr;
96 typedef sycl::event *event_ptr;
97 typedef char *device_ptr;
98 typedef uint8_t byte_t;
99 typedef sycl::buffer<byte_t> buffer_t;
100
    /// SYCL default exception handler.
    /// Installed on queues created with exception handling enabled; it
    /// rethrows each queued asynchronous exception and logs sycl::exception
    /// messages to stderr. Exceptions of any other type propagate out of
    /// the handler uncaught.
    inline auto exception_handler = [](sycl::exception_list exceptions)
    {
        for (std::exception_ptr const &e : exceptions)
        {
            try
            {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const &e)
            {
                // NOTE: __FILE__/__LINE__ expand to this handler's location,
                // not the site that enqueued the failing operation.
                std::cerr << "Caught asynchronous SYCL exception:" << std::endl
                          << e.what() << std::endl
                          << "Exception caught at file:" << __FILE__
                          << ", line:" << __LINE__ << std::endl;
            }
        }
    };
119
    /// Generic success/failure codes returned by the compatibility layer.
    enum error_code
    {
        success = 0,
        default_error = 999
    };

    /// Direction of a memory copy; \c automatic lets the runtime deduce the
    /// direction from the pointer kinds involved.
    enum memcpy_direction
    {
        host_to_host,
        host_to_device,
        device_to_host,
        device_to_device,
        automatic
    };

    /// Kinds of memory regions an allocation may live in.
    enum memory_region
    {
        global = 0, // device global memory
        constant,   // device constant memory
        local,      // device local memory
        shared,     // memory which can be accessed by host and device
    };

    /// Scalar/element-type tags describing library (BLAS-style) operands,
    /// analogous to cudaDataType_t. Values are stable identifiers; do not
    /// reorder.
    enum class library_data_t : unsigned char
    {
        real_float = 0,
        complex_float,
        real_double,
        complex_double,
        real_half,
        complex_half,
        real_bfloat16,
        complex_bfloat16,
        real_int4,
        complex_int4,
        real_uint4,
        complex_uint4,
        real_int8,
        complex_int8,
        real_uint8,
        complex_uint8,
        real_int16,
        complex_int16,
        real_uint16,
        complex_uint16,
        real_int32,
        complex_int32,
        real_uint32,
        complex_uint32,
        real_int64,
        complex_int64,
        real_uint64,
        complex_uint64,
        real_int8_4,
        real_int8_32,
        real_uint8_4,
        library_data_t_size // sentinel: number of tags, not a valid tag
    };
178
    /// Maps a storage element type to the computation type used by the math
    /// wrappers: scalar types map to themselves, while sycl::vec<T, 2>
    /// (used to store complex values) maps to std::complex<T>.
    template <typename T>
    struct DataType
    {
        using T2 = T;
    };
    template <typename T>
    struct DataType<sycl::vec<T, 2>>
    {
        using T2 = std::complex<T>;
    };
189
    /// Deletes an event previously allocated on the heap by this layer.
    /// \p event must have been obtained with `new sycl::event` (or be null).
    static void destroy_event(event_ptr event)
    {
        delete event;
    }
194
    /// Returns the OS thread id of the calling thread; used as the key of
    /// the thread -> device map in dev_mgr.
    static inline unsigned int get_tid()
    {
#if defined(__linux__)
        // Raw syscall: glibc only gained a gettid() wrapper in 2.30.
        return syscall(SYS_gettid);
#elif defined(_WIN64)
        return GetCurrentThreadId();
#else
#error "Only support Windows and Linux."
#endif
    }
205
206 namespace detail
207 {
208 static void get_version(const sycl::device &dev, int &major, int &minor)
209 {
210 // Version string has the following format:
211 // a. OpenCL<space><major.minor><space><vendor-specific-information>
212 // b. <major.minor>
213 // c. <AmdGcnArchName> e.g gfx1030
214 std::string ver;
215 ver = dev.get_info<sycl::info::device::version>();
216 std::string::size_type i = 0;
217 while (i < ver.size()) {
218 if (isdigit(ver[i]))
219 break;
220 i++;
221 }
222 major = std::stoi(&(ver[i]));
223 while (i < ver.size()) {
224 if (ver[i] == '.')
225 break;
226 i++;
227 }
228 if (i < ver.size()) {
229 // a. and b.
230 i++;
231 minor = std::stoi(&(ver[i]));
232 } else {
233 // c.
234 minor = 0;
235 }
236 }
237
        /// Strongly-typed wrapper around an error value; the \p tag type
        /// parameter distinguishes otherwise-identical error domains at
        /// compile time (each tag instantiates a distinct type).
        template <typename tag, typename T>
        class generic_error_type
        {
        public:
            generic_error_type() = default;
            generic_error_type(T value) : value{value} {}
            // Implicit conversion back to the underlying value.
            operator T() const { return value; }

        private:
            T value;
        };
249
250 } // namespace detail
251
    // COPY from DPCT head files
    /// dim3 is used to store 3 component dimensions (CUDA-style launch
    /// geometry; x is the fastest-varying component).
    class dim3 {
    public:
        unsigned x, y, z;

        constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
            : x(x), y(y), z(z) {}

        // sycl::id<3> orders components slowest-first, so reverse them.
        dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}

        // Likewise reversed on the way out: range is (z, y, x).
        operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
    }; // class dim3
265
266 inline dim3 operator*(const dim3 &a, const dim3 &b) {
267 return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
268 }
269 // COPY from DPCT head files
270
271
272 /// Pitched 2D/3D memory data.
273 class pitched_data
274 {
275 public:
276 pitched_data() : pitched_data(nullptr, 0, 0, 0) {}
277 pitched_data(void *data, size_t pitch, size_t x, size_t y)
278 : _data(data), _pitch(pitch), _x(x), _y(y) {}
279
280 void *get_data_ptr() { return _data; }
281 void set_data_ptr(void *data) { _data = data; }
282
283 size_t get_pitch() { return _pitch; }
284 void set_pitch(size_t pitch) { _pitch = pitch; }
285
286 size_t get_x() { return _x; }
287 void set_x(size_t x) { _x = x; }
288
289 size_t get_y() { return _y; }
290 void set_y(size_t y) { _y = y; }
291
292 private:
293 void *_data;
294 size_t _pitch, _x, _y;
295 };
296
297 class device_info
298 {
299 public:
300 // get interface
301 const char *get_name() const { return _name; }
302 char *get_name() { return _name; }
303 template <typename WorkItemSizesTy = sycl::range<3>,
304 std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
305 std::is_same_v<WorkItemSizesTy, int *>,
306 int> = 0>
307 auto get_max_work_item_sizes() const
308 {
309 if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
310 return sycl::range<3>(_max_work_item_sizes_i[0],
311 _max_work_item_sizes_i[1],
312 _max_work_item_sizes_i[2]);
313 else
314 {
315 return _max_work_item_sizes_i;
316 }
317 }
318 template <typename WorkItemSizesTy = sycl::range<3>,
319 std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
320 std::is_same_v<WorkItemSizesTy, int *>,
321 int> = 0>
322 auto get_max_work_item_sizes()
323 {
324 if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
325 return sycl::range<3>(_max_work_item_sizes_i[0],
326 _max_work_item_sizes_i[1],
327 _max_work_item_sizes_i[2]);
328 else
329 {
330 return _max_work_item_sizes_i;
331 }
332 }
333 bool get_host_unified_memory() const { return _host_unified_memory; }
334 int get_major_version() const { return _major; }
335 int get_minor_version() const { return _minor; }
336 int get_integrated() const { return _integrated; }
337 int get_max_clock_frequency() const { return _frequency; }
338 int get_max_compute_units() const { return _max_compute_units; }
339 int get_max_work_group_size() const { return _max_work_group_size; }
340 int get_max_sub_group_size() const { return _max_sub_group_size; }
341 int get_max_work_items_per_compute_unit() const
342 {
343 return _max_work_items_per_compute_unit;
344 }
345 int get_max_register_size_per_work_group() const
346 {
347 return _max_register_size_per_work_group;
348 }
349 template <typename NDRangeSizeTy = size_t *,
350 std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
351 std::is_same_v<NDRangeSizeTy, int *>,
352 int> = 0>
353 auto get_max_nd_range_size() const
354 {
355 if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
356 return _max_nd_range_size;
357 else
358 return _max_nd_range_size_i;
359 }
360 template <typename NDRangeSizeTy = size_t *,
361 std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
362 std::is_same_v<NDRangeSizeTy, int *>,
363 int> = 0>
364 auto get_max_nd_range_size()
365 {
366 if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
367 return _max_nd_range_size;
368 else
369 return _max_nd_range_size_i;
370 }
371 size_t get_global_mem_size() const { return _global_mem_size; }
372 size_t get_local_mem_size() const { return _local_mem_size; }
373 size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
374 /// Returns the maximum clock rate of device's global memory in kHz. If
375 /// compiler does not support this API then returns default value 3200000 kHz.
376 unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
377 /// Returns the maximum bus width between device and memory in bits. If
378 /// compiler does not support this API then returns default value 64 bits.
379 unsigned int get_memory_bus_width() const { return _memory_bus_width; }
380 uint32_t get_device_id() const { return _device_id; }
381 std::array<unsigned char, 16> get_uuid() const { return _uuid; }
382 /// Returns global memory cache size in bytes.
383 unsigned int get_global_mem_cache_size() const
384 {
385 return _global_mem_cache_size;
386 }
387
388 // set interface
389 void set_name(const char *name)
390 {
391 size_t length = strlen(name);
392 if (length < 256)
393 {
394 std::memcpy(_name, name, length + 1);
395 }
396 else
397 {
398 std::memcpy(_name, name, 255);
399 _name[255] = '\0';
400 }
401 }
402 void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
403 {
404 for (int i = 0; i < 3; ++i)
405 _max_work_item_sizes_i[i] = max_work_item_sizes[i];
406 }
407 [[deprecated]] void
408 set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
409 {
410 for (int i = 0; i < 3; ++i)
411 {
412 _max_work_item_sizes_i[i] = max_work_item_sizes[i];
413 }
414 }
415 void set_host_unified_memory(bool host_unified_memory)
416 {
417 _host_unified_memory = host_unified_memory;
418 }
419 void set_major_version(int major) { _major = major; }
420 void set_minor_version(int minor) { _minor = minor; }
421 void set_integrated(int integrated) { _integrated = integrated; }
422 void set_max_clock_frequency(int frequency) { _frequency = frequency; }
423 void set_max_compute_units(int max_compute_units)
424 {
425 _max_compute_units = max_compute_units;
426 }
427 void set_global_mem_size(size_t global_mem_size)
428 {
429 _global_mem_size = global_mem_size;
430 }
431 void set_local_mem_size(size_t local_mem_size)
432 {
433 _local_mem_size = local_mem_size;
434 }
435 void set_max_mem_alloc_size(size_t max_mem_alloc_size)
436 {
437 _max_mem_alloc_size = max_mem_alloc_size;
438 }
439 void set_max_work_group_size(int max_work_group_size)
440 {
441 _max_work_group_size = max_work_group_size;
442 }
443 void set_max_sub_group_size(int max_sub_group_size)
444 {
445 _max_sub_group_size = max_sub_group_size;
446 }
447 void
448 set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
449 {
450 _max_work_items_per_compute_unit = max_work_items_per_compute_unit;
451 }
452 void set_max_nd_range_size(int max_nd_range_size[])
453 {
454 for (int i = 0; i < 3; i++)
455 {
456 _max_nd_range_size[i] = max_nd_range_size[i];
457 _max_nd_range_size_i[i] = max_nd_range_size[i];
458 }
459 }
460 void set_memory_clock_rate(unsigned int memory_clock_rate)
461 {
462 _memory_clock_rate = memory_clock_rate;
463 }
464 void set_memory_bus_width(unsigned int memory_bus_width)
465 {
466 _memory_bus_width = memory_bus_width;
467 }
468 void
469 set_max_register_size_per_work_group(int max_register_size_per_work_group)
470 {
471 _max_register_size_per_work_group = max_register_size_per_work_group;
472 }
473 void set_device_id(uint32_t device_id)
474 {
475 _device_id = device_id;
476 }
477 void set_uuid(std::array<unsigned char, 16> uuid)
478 {
479 _uuid = std::move(uuid);
480 }
481 void set_global_mem_cache_size(unsigned int global_mem_cache_size)
482 {
483 _global_mem_cache_size = global_mem_cache_size;
484 }
485
486 private:
487 char _name[256];
488 int _max_work_item_sizes_i[3];
489 bool _host_unified_memory = false;
490 int _major;
491 int _minor;
492 int _integrated = 0;
493 int _frequency;
494 // Set estimated value 3200000 kHz as default value.
495 unsigned int _memory_clock_rate = 3200000;
496 // Set estimated value 64 bits as default value.
497 unsigned int _memory_bus_width = 64;
498 unsigned int _global_mem_cache_size;
499 int _max_compute_units;
500 int _max_work_group_size;
501 int _max_sub_group_size;
502 int _max_work_items_per_compute_unit;
503 int _max_register_size_per_work_group;
504 size_t _global_mem_size;
505 size_t _local_mem_size;
506 size_t _max_mem_alloc_size;
507 size_t _max_nd_range_size[3];
508 int _max_nd_range_size_i[3];
509 uint32_t _device_id;
510 std::array<unsigned char, 16> _uuid;
511 };
512
513 static int get_major_version(const sycl::device &dev)
514 {
515 int major, minor;
516 detail::get_version(dev, major, minor);
517 return major;
518 }
519
520 static int get_minor_version(const sycl::device &dev)
521 {
522 int major, minor;
523 detail::get_version(dev, major, minor);
524 return minor;
525 }
526
    /// Populate \p out with the properties of \p dev, translating SYCL
    /// device queries into the CUDA-style fields of device_info. Fields
    /// the compiler/runtime cannot report keep device_info's defaults.
    static void get_device_info(device_info &out, const sycl::device &dev)
    {
        device_info prop;
        prop.set_name(dev.get_info<sycl::info::device::name>().c_str());

        int major, minor;
        detail::get_version(dev, major, minor);
        prop.set_major_version(major);
        prop.set_minor_version(minor);

        prop.set_max_work_item_sizes(
#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
            // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes
            // is an enum class element
            dev.get_info<sycl::info::device::max_work_item_sizes>());
#else
            // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by
            // an int
            dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
#endif
        prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));

        // x1000: SYCL reports MHz, device_info stores kHz.
        prop.set_max_clock_frequency(
            dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);

        prop.set_max_compute_units(
            dev.get_info<sycl::info::device::max_compute_units>());
        prop.set_max_work_group_size(
            dev.get_info<sycl::info::device::max_work_group_size>());
        prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
        prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
        prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());

#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
        // Intel extension queries; each is guarded by an aspect check so
        // non-Intel devices fall back to the defaults.
        if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
        {
            unsigned int tmp =
                dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
            if (tmp != 0)
                prop.set_memory_clock_rate(1000 * tmp);
        }
        if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
        {
            prop.set_memory_bus_width(
                dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_id))
        {
            prop.set_device_id(
                dev.get_info<sycl::ext::intel::info::device::device_id>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
        {
            prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
        }
#elif defined(_MSC_VER) && !defined(__clang__)
#pragma message("get_device_info: querying memory_clock_rate and \
    memory_bus_width are not supported by the compiler used. \
    Use 3200000 kHz as memory_clock_rate default value. \
    Use 64 bits as memory_bus_width default value.")
#else
#warning "get_device_info: querying memory_clock_rate and \
    memory_bus_width are not supported by the compiler used. \
    Use 3200000 kHz as memory_clock_rate default value. \
    Use 64 bits as memory_bus_width default value."
#endif

        // Largest supported sub-group (SIMD) width.
        size_t max_sub_group_size = 1;
        std::vector<size_t> sub_group_sizes =
            dev.get_info<sycl::info::device::sub_group_sizes>();

        for (const auto &sub_group_size : sub_group_sizes)
        {
            if (max_sub_group_size < sub_group_size)
                max_sub_group_size = sub_group_size;
        }

        prop.set_max_sub_group_size(max_sub_group_size);

        prop.set_max_work_items_per_compute_unit(
            dev.get_info<sycl::info::device::max_work_group_size>());
        // SYCL has no ND-range limit query; assume INT_MAX per dimension.
        int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
        prop.set_max_nd_range_size(max_nd_range_size);

        // Estimates max register size per work group, feel free to update the value
        // according to device properties.
        prop.set_max_register_size_per_work_group(65536);

        prop.set_global_mem_cache_size(
            dev.get_info<sycl::info::device::global_mem_cache_size>());
        out = prop;
    }
619
    /// dpct device extension
    /// Wraps a sycl::device together with the queues created on it. Owns an
    /// in-order and an out-of-order default queue plus any user-created
    /// queues (_queues); queue bookkeeping is guarded by m_mutex.
    class device_ext : public sycl::device {
        typedef std::mutex mutex_type;

    public:
        device_ext() : sycl::device() {}
        ~device_ext() {
            std::lock_guard<mutex_type> lock(m_mutex);
            clear_queues();
        }
        device_ext(const sycl::device &base) : sycl::device(base) {
            std::lock_guard<mutex_type> lock(m_mutex);
            init_queues();
        }

        // Always reports "not supported"; kept for CUDA-API parity.
        int is_native_atomic_supported() { return 0; }
        int get_major_version() const { return dpct::get_major_version(*this); }

        int get_minor_version() const { return dpct::get_minor_version(*this); }

        int get_max_compute_units() const {
            return get_device_info().get_max_compute_units();
        }

        /// Return the maximum clock frequency of this device in KHz.
        int get_max_clock_frequency() const {
            return get_device_info().get_max_clock_frequency();
        }

        int get_integrated() const { return get_device_info().get_integrated(); }

        int get_max_sub_group_size() const {
            return get_device_info().get_max_sub_group_size();
        }

        int get_max_register_size_per_work_group() const {
            return get_device_info().get_max_register_size_per_work_group();
        }

        int get_max_work_group_size() const {
            return get_device_info().get_max_work_group_size();
        }

        int get_mem_base_addr_align() const {
            return get_info<sycl::info::device::mem_base_addr_align>();
        }

        size_t get_global_mem_size() const {
            return get_device_info().get_global_mem_size();
        }

        size_t get_max_mem_alloc_size() const {
            return get_device_info().get_max_mem_alloc_size();
        }

        /// Get the number of bytes of free and total memory on the SYCL device.
        /// \param [out] free_memory The number of bytes of free memory on the
        /// SYCL device. \param [out] total_memory The number of bytes of total
        /// memory on the SYCL device.
        /// Falls back to free == total when the free-memory query is
        /// unavailable (needs ZES_ENABLE_SYSMAN=1 and a recent compiler).
        void get_memory_info(size_t &free_memory, size_t &total_memory) {
            total_memory = get_device_info().get_global_mem_size();
            const char *warning_info =
                "get_memory_info: [warning] ext_intel_free_memory is not "
                "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
                "use total memory as free memory";
#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
            if (!has(sycl::aspect::ext_intel_free_memory)) {
                std::cerr << warning_info << std::endl;
                free_memory = total_memory;
            } else {
                free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
            }
#else
            std::cerr << warning_info << std::endl;
            free_memory = total_memory;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma message("Querying the number of bytes of free memory is not supported")
#else
#warning "Querying the number of bytes of free memory is not supported"
#endif
#endif
        }

        void get_device_info(device_info &out) const {
            dpct::get_device_info(out, *this);
        }

        // NOTE: queries the device afresh on every call; each getter above
        // therefore re-runs the full property scan.
        device_info get_device_info() const {
            device_info prop;
            dpct::get_device_info(prop, *this);
            return prop;
        }

        /// Drop all queues and recreate the default in/out-of-order pair.
        void reset() {
            std::lock_guard<mutex_type> lock(m_mutex);
            clear_queues();
            init_queues();
        }

        sycl::queue &in_order_queue() { return _q_in_order; }

        sycl::queue &out_of_order_queue() { return _q_out_of_order; }

        sycl::queue &default_queue() { return in_order_queue(); }

        /// Wait for all queues to drain, rethrowing any asynchronous errors.
        void queues_wait_and_throw() {
            std::unique_lock<mutex_type> lock(m_mutex);
            // Deliberately unlocked while waiting so other threads can still
            // create/destroy queues without blocking on long-running work.
            lock.unlock();
            for (auto &q : _queues) {
                q.wait_and_throw();
            }
            // Guard the destruct of current_queues to make sure the ref count is
            // safe.
            lock.lock();
        }

        sycl::queue create_queue(bool enable_exception_handler = false) {
            return create_in_order_queue(enable_exception_handler);
        }

        sycl::queue create_queue(sycl::device device,
                                 bool enable_exception_handler = false) {
            return create_in_order_queue(device, enable_exception_handler);
        }

        sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(enable_exception_handler,
                                     sycl::property::queue::in_order());
        }

        sycl::queue create_in_order_queue(sycl::device device,
                                          bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(device, enable_exception_handler,
                                     sycl::property::queue::in_order());
        }

        sycl::queue create_out_of_order_queue(
            bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(enable_exception_handler);
        }

        /// Remove \p queue from the tracked list (erase-remove idiom);
        /// the queue object itself is reference-counted by SYCL.
        void destroy_queue(sycl::queue queue) {
            std::lock_guard<mutex_type> lock(m_mutex);
            _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
                                         [=](const sycl::queue &q) -> bool
                                         {
                                             return q == queue;
                                         }),
                          _queues.end());
        }
        void set_saved_queue(sycl::queue q) {
            std::lock_guard<mutex_type> lock(m_mutex);
            _saved_queue = q;
        }
        sycl::queue get_saved_queue() const {
            std::lock_guard<mutex_type> lock(m_mutex);
            return _saved_queue;
        }

    private:
        void clear_queues() { _queues.clear(); }

        // Exception handling is always enabled on the two default queues.
        void init_queues() {
            _q_in_order =
                create_queue_impl(true, sycl::property::queue::in_order());
            _q_out_of_order = create_queue_impl(true);
            _saved_queue = default_queue();
        }

        /// Caller should acquire resource \p m_mutex before calling this
        /// function.
        template <class... Properties>
        sycl::queue create_queue_impl(bool enable_exception_handler,
                                      Properties... properties) {
            sycl::async_handler eh = {};
            if (enable_exception_handler) {
                eh = exception_handler;
            }
            _queues.push_back(sycl::queue(
                *this, eh,
                sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED
                    sycl::property::queue::enable_profiling(),
#endif
                    properties...)));

            return _queues.back();
        }

        // Overload creating the queue on an explicit \p device rather than
        // on this device. Caller must hold m_mutex.
        template <class... Properties>
        sycl::queue create_queue_impl(sycl::device device,
                                      bool enable_exception_handler,
                                      Properties... properties) {
            sycl::async_handler eh = {};
            if (enable_exception_handler) {
                eh = exception_handler;
            }
            _queues.push_back(sycl::queue(
                device, eh,
                sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED
                    sycl::property::queue::enable_profiling(),
#endif
                    properties...)));

            return _queues.back();
        }

        void get_version(int &major, int &minor) const {
            detail::get_version(*this, major, minor);
        }
        sycl::queue _q_in_order, _q_out_of_order;
        sycl::queue _saved_queue;
        std::vector<sycl::queue> _queues;
        // mutable so const getters (get_saved_queue) can lock it.
        mutable mutex_type m_mutex;
    };
839
840
841 /// device manager
842 class dev_mgr
843 {
844 public:
845 device_ext &current_device()
846 {
847 unsigned int dev_id = current_device_id();
848 check_id(dev_id);
849 return *_devs[dev_id];
850 }
851 device_ext &cpu_device() const
852 {
853 std::lock_guard<std::recursive_mutex> lock(m_mutex);
854 if (_cpu_device == -1)
855 {
856 throw std::runtime_error("no valid cpu device");
857 }
858 else
859 {
860 return *_devs[_cpu_device];
861 }
862 }
863 device_ext &get_device(unsigned int id) const
864 {
865 std::lock_guard<std::recursive_mutex> lock(m_mutex);
866 check_id(id);
867 return *_devs[id];
868 }
869 unsigned int current_device_id() const
870 {
871 std::lock_guard<std::recursive_mutex> lock(m_mutex);
872 auto it = _thread2dev_map.find(get_tid());
873 if (it != _thread2dev_map.end())
874 return it->second;
875 return DEFAULT_DEVICE_ID;
876 }
877
878 /// Select device with a device ID.
879 /// \param [in] id The id of the device which can
880 /// be obtained through get_device_id(const sycl::device).
881 void select_device(unsigned int id)
882 {
883 std::lock_guard<std::recursive_mutex> lock(m_mutex);
884 check_id(id);
885 _thread2dev_map[get_tid()] = id;
886 }
887 unsigned int device_count() { return _devs.size(); }
888
889 unsigned int get_device_id(const sycl::device &dev)
890 {
891 unsigned int id = 0;
892 for (auto &dev_item : _devs)
893 {
894 if (*dev_item == dev)
895 {
896 return id;
897 }
898 id++;
899 }
900 return -1;
901 }
902
903 inline std::string get_preferred_gpu_platform_name() {
904 std::string result;
905
906 std::string filter = "";
907 char* env = getenv("ONEAPI_DEVICE_SELECTOR");
908 if (env) {
909 if (std::strstr(env, "level_zero")) {
910 filter = "level-zero";
911 }
912 else if (std::strstr(env, "opencl")) {
913 filter = "opencl";
914 }
915 else if (std::strstr(env, "cuda")) {
916 filter = "cuda";
917 }
918 else if (std::strstr(env, "hip")) {
919 filter = "hip";
920 }
921 else {
922 throw std::runtime_error("invalid device filter: " + std::string(env));
923 }
924 } else {
925 auto default_device = sycl::device(sycl::default_selector_v);
926 auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
927
928 if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
929 filter = "level-zero";
930 }
931 else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
932 filter = "cuda";
933 }
934 else if (std::strstr(default_platform_name.c_str(), "HIP")) {
935 filter = "hip";
936 }
937 }
938
939 auto platform_list = sycl::platform::get_platforms();
940
941 for (const auto& platform : platform_list) {
942 auto devices = platform.get_devices();
943 auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
944 return d.is_gpu();
945 });
946
947 if (gpu_dev == devices.end()) {
948 // cout << "platform [" << platform_name
949 // << "] does not contain GPU devices, skipping\n";
950 continue;
951 }
952
953 auto platform_name = platform.get_info<sycl::info::platform::name>();
954 std::string platform_name_low_case;
955 platform_name_low_case.resize(platform_name.size());
956
957 std::transform(
958 platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
959
960 if (platform_name_low_case.find(filter) == std::string::npos) {
961 // cout << "platform [" << platform_name
962 // << "] does not match with requested "
963 // << filter << ", skipping\n";
964 continue;
965 }
966
967 result = platform_name;
968 }
969
970 if (result.empty())
971 throw std::runtime_error("can not find preferred GPU platform");
972
973 return result;
974 }
975
976 template <class DeviceSelector>
977 std::enable_if_t<
978 std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
979 select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
980 {
981 sycl::device selected_device = sycl::device(selector);
982 unsigned int selected_device_id = get_device_id(selected_device);
983 select_device(selected_device_id);
984 }
985
986 /// Returns the instance of device manager singleton.
987 static dev_mgr &instance()
988 {
989 static dev_mgr d_m;
990 return d_m;
991 }
992 dev_mgr(const dev_mgr &) = delete;
993 dev_mgr &operator=(const dev_mgr &) = delete;
994 dev_mgr(dev_mgr &&) = delete;
995 dev_mgr &operator=(dev_mgr &&) = delete;
996
997 private:
998 mutable std::recursive_mutex m_mutex;
999 static bool compare_dev(sycl::device &device1, sycl::device &device2)
1000 {
1001 sycl::backend backend1 = device1.get_backend();
1002 sycl::backend backend2 = device2.get_backend();
1003 // levelzero backends always come first
1004 if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
1005 if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
1006 dpct::device_info prop1;
1007 dpct::get_device_info(prop1, device1);
1008 dpct::device_info prop2;
1009 dpct::get_device_info(prop2, device2);
1010 return prop1.get_max_compute_units() > prop2.get_max_compute_units();
1011 }
1012 static int convert_backend_index(std::string & backend) {
1013 if (backend == "ext_oneapi_level_zero:gpu") return 0;
1014 if (backend == "opencl:gpu") return 1;
1015 if (backend == "ext_oneapi_cuda:gpu") return 2;
1016 if (backend == "ext_oneapi_hip:gpu") return 3;
1017 if (backend == "opencl:cpu") return 4;
1018 if (backend == "opencl:acc") return 5;
1019 printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
1020 GGML_ABORT("fatal error");
1021 }
1022 static bool compare_backend(std::string &backend1, std::string &backend2) {
1023 return convert_backend_index(backend1) < convert_backend_index(backend2);
1024 }
        // Enumerates all SYCL devices once at construction.
        // The default-selected device always occupies slot 0; the remaining
        // devices of the preferred GPU platform are appended grouped by
        // backend priority (compare_backend) and, within a backend, ordered
        // by compare_dev.
        dev_mgr()
        {
            sycl::device default_device =
                sycl::device(sycl::default_selector_v);
            _devs.push_back(std::make_shared<device_ext>(default_device));

            std::vector<sycl::device> sycl_all_devs;
            // Collect other devices except for the default device.
            if (default_device.is_cpu())
                _cpu_device = 0;

            auto Platforms = sycl::platform::get_platforms();
            // Keep track of the number of devices per backend
            // NOTE(review): DeviceNums is never written below — appears unused.
            std::map<sycl::backend, size_t> DeviceNums;
            std::map<std::string, std::vector<sycl::device>> backend_devices;
            auto preferred_platform_name = get_preferred_gpu_platform_name();

            // Group the preferred platform's devices by their backend/type
            // label; every other platform is skipped entirely.
            while (!Platforms.empty()) {
                auto Platform = Platforms.back();
                Platforms.pop_back();
                auto platform_name = Platform.get_info<sycl::info::platform::name>();
                if (platform_name.compare(preferred_platform_name) != 0) {
                    continue;
                }
                auto devices = Platform.get_devices();
                std::string backend_type = get_device_backend_and_type(devices[0]);
                for (const auto &device : devices) {
                    backend_devices[backend_type].push_back(device);
                }
            }

            // Visit backends in fixed priority order (level-zero first).
            std::vector<std::string> keys;
            for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
                keys.push_back(it->first);
            }
            std::sort(keys.begin(), keys.end(), compare_backend);

            for (auto &key : keys) {
                std::vector<sycl::device> devs = backend_devices[key];
                std::sort(devs.begin(), devs.end(), compare_dev);
                for (const auto &dev : devs) {
                    sycl_all_devs.push_back(dev);
                }
            }

            // Append everything except the default device (already at slot 0),
            // remembering the index of the first CPU device encountered.
            for (auto &dev : sycl_all_devs)
            {
                if (dev == default_device)
                {
                    continue;
                }
                _devs.push_back(std::make_shared<device_ext>(dev));
                if (_cpu_device == -1 && dev.is_cpu())
                {
                    _cpu_device = _devs.size() - 1;
                }
            }
        }
1083 void check_id(unsigned int id) const
1084 {
1085 if (id >= _devs.size())
1086 {
1087 throw std::runtime_error("invalid device id");
1088 }
1089 }
        // All managed devices; index 0 is always the default-selected device.
        std::vector<std::shared_ptr<device_ext>> _devs;
        /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current
        /// thread id in _thread2dev_map, which means default device should be used
        /// for the current thread.
        const unsigned int DEFAULT_DEVICE_ID = 0;
        /// thread-id to device-id map.
        std::map<unsigned int, unsigned int> _thread2dev_map;
        // Index of the first CPU device in _devs, or -1 if none was found.
        int _cpu_device = -1;
1098 };
1099
1100 static inline sycl::queue &get_default_queue()
1101 {
1102 return dev_mgr::instance().current_device().default_queue();
1103 }
1104
1105 namespace detail
1106 {
        // Classifies where a USM pointer may be dereferenced.
        enum class pointer_access_attribute
        {
            host_only = 0,   // non-USM / unknown pointer: host access only
            device_only,     // device USM allocation
            host_device,     // shared or host USM: accessible from both sides
            end              // attribute count; used to size lookup tables
        };
1114
1115 static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
1116 const void *ptr)
1117 {
1118 switch (sycl::get_pointer_type(ptr, q.get_context()))
1119 {
1120 case sycl::usm::alloc::unknown:
1121 return pointer_access_attribute::host_only;
1122 case sycl::usm::alloc::device:
1123 return pointer_access_attribute::device_only;
1124 case sycl::usm::alloc::shared:
1125 case sycl::usm::alloc::host:
1126 return pointer_access_attribute::host_device;
1127 }
1128 }
1129
1130 template <typename ArgT>
1131 inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
1132 {
1133 static_assert((unsigned char)library_data_t::library_data_t_size <=
1134 std::numeric_limits<unsigned char>::max() &&
1135 "library_data_t size exceeds limit.");
1136 static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
1137 return (std::uint64_t)Val;
1138 }
1139
1140 template <typename FirstT, typename... RestT>
1141 inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
1142 RestT... RestVal)
1143 {
1144 static_assert((std::uint8_t)library_data_t::library_data_t_size <=
1145 std::numeric_limits<unsigned char>::max() &&
1146 "library_data_t size exceeds limit.");
1147 static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
1148 static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
1149 return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
1150 }
1151
        /// Singleton pool of "virtual" device pointers backed by SYCL buffers.
        /// A large address range is reserved (never committed) purely so that
        /// pointer arithmetic on the returned addresses is meaningful; each
        /// allocation maps a sub-range of it to a sycl buffer via m_map,
        /// keyed by the allocation's one-past-end address.
        class mem_mgr
        {
            mem_mgr()
            {
                // Reserved address space, no real memory allocation happens here.
#if defined(__linux__)
                mapped_address_space =
                    (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
                                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
#elif defined(_WIN64)
                mapped_address_space = (byte_t *)VirtualAlloc(
                    NULL,               // NULL specified as the base address parameter
                    mapped_region_size, // Size of allocation
                    MEM_RESERVE,        // Allocate reserved pages
                    PAGE_NOACCESS);     // Protection = no access
#else
#error "Only support Windows and Linux."
#endif
                next_free = mapped_address_space;
            }

        public:
            using buffer_id_t = int;

            // One pool entry: backing buffer, virtual base pointer and size.
            struct allocation
            {
                buffer_t buffer;
                byte_t *alloc_ptr;
                size_t size;
            };

            ~mem_mgr()
            {
#if defined(__linux__)
                munmap(mapped_address_space, mapped_region_size);
#elif defined(_WIN64)
                VirtualFree(mapped_address_space, 0, MEM_RELEASE);
#else
#error "Only support Windows and Linux."
#endif
            }

            // Singleton: not copyable, not movable.
            mem_mgr(const mem_mgr &) = delete;
            mem_mgr &operator=(const mem_mgr &) = delete;
            mem_mgr(mem_mgr &&) = delete;
            mem_mgr &operator=(mem_mgr &&) = delete;

            /// Allocate
            void *mem_alloc(size_t size)
            {
                if (!size)
                    return nullptr;
                std::lock_guard<std::mutex> lock(m_mutex);
                if (next_free + size > mapped_address_space + mapped_region_size)
                {
                    throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
                }
                // Allocation
                sycl::range<1> r(size);
                buffer_t buf(r);
                allocation A{buf, next_free, size};
                // Map allocation to device pointer
                // Key is the one-past-end address so that map::upper_bound on
                // any interior pointer finds this entry (see get_map_iterator).
                void *result = next_free;
                m_map.emplace(next_free + size, A);
                // Update pointer to the next free space.
                next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);

                return result;
            }

            /// Deallocate
            void mem_free(const void *ptr)
            {
                if (!ptr)
                    return;
                std::lock_guard<std::mutex> lock(m_mutex);
                auto it = get_map_iterator(ptr);
                m_map.erase(it);
            }

            /// map: device pointer -> allocation(buffer, alloc_ptr, size)
            allocation translate_ptr(const void *ptr)
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                auto it = get_map_iterator(ptr);
                return it->second;
            }

            /// Check if the pointer represents device pointer or not.
            bool is_device_ptr(const void *ptr) const
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                return (mapped_address_space <= ptr) &&
                       (ptr < mapped_address_space + mapped_region_size);
            }

            /// Returns the instance of memory manager singleton.
            static mem_mgr &instance()
            {
                static mem_mgr m;
                return m;
            }

        private:
            std::map<byte_t *, allocation> m_map;  // key: allocation end pointer
            mutable std::mutex m_mutex;
            byte_t *mapped_address_space;          // base of the reserved range
            byte_t *next_free;                     // bump pointer into the range
            const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
            const size_t alignment = 256;
            /// This padding may be defined to some positive value to debug
            /// out of bound accesses.
            const size_t extra_padding = 0;

            // Finds the pool entry containing \p ptr; throws if \p ptr is not
            // a virtual pointer or falls in an alignment/padding gap.
            std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
            {
                auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
                if (it == m_map.end())
                {
                    // Not a virtual pointer.
                    throw std::runtime_error("can not get buffer from non-virtual pointer");
                }
                const allocation &alloc = it->second;
                if (ptr < alloc.alloc_ptr)
                {
                    // Out of bound.
                    // This may happen if there's a gap between allocations due to alignment
                    // or extra padding and pointer points to this gap.
                    throw std::runtime_error("invalid virtual pointer");
                }
                return it;
            }
        };
1285
1286 template <class T, memory_region Memory, size_t Dimension>
1287 class accessor;
        /// Compile-time properties of a dpct memory region: access target and
        /// mode, element type, and the accessor type used inside kernels.
        template <memory_region Memory, class T = byte_t>
        class memory_traits
        {
        public:
            static constexpr sycl::access::target target =
                sycl::access::target::device;
            // constant memory is read-only from kernels; everything else is RW.
            static constexpr sycl::access_mode mode =
                (Memory == constant) ? sycl::access_mode::read
                                     : sycl::access_mode::read_write;
            static constexpr size_t type_size = sizeof(T);
            // Elements of constant memory are exposed as const.
            using element_t =
                typename std::conditional<Memory == constant, const T, T>::type;
            using value_t = typename std::remove_cv<T>::type;
            // local memory uses sycl::local_accessor; other regions a plain
            // device accessor.
            template <size_t Dimension = 1>
            using accessor_t = typename std::conditional<
                Memory == local, sycl::local_accessor<value_t, Dimension>,
                sycl::accessor<T, Dimension, mode, target>>::type;
            using pointer_t = T *;
        };
1307
1308 static inline void *dpct_malloc(size_t size, sycl::queue &q)
1309 {
1310 return sycl::malloc_device(size, q.get_device(), q.get_context());
1311 }
1312
1313#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
1314 static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
1315 sycl::queue &q)
1316 {
1317 pitch = PITCH_DEFAULT_ALIGN(x);
1318 return dpct_malloc(pitch * y * z, q);
1319 }
1320
1321 /**
1322 * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
1323 * @tparam valueT The type of the element to be set.
1324 * @param [in] q The queue in which the operation is done.
1325 * @param [in] dev_ptr Pointer to the virtual device memory address.
1326 * @param [in] value The value to be set.
1327 * @param [in] size Number of elements to be set to the value.
1328 * @return An event representing the memset operation.
1329 */
1330 template <typename valueT>
1331 static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
1332 valueT value, size_t size)
1333 {
1334 return q.fill(dev_ptr, value, size);
1335 }
1336
1337 /**
1338 * @brief Sets \p value to the 3D memory region pointed by \p data in \p q.
1339 * @tparam valueT The type of the element to be set.
1340 * @param [in] q The queue in which the operation is done.
1341 * @param [in] data Pointer to the pitched device memory region.
1342 * @param [in] value The value to be set.
1343 * @param [in] size 3D memory region by number of elements.
1344 * @return An event list representing the memset operations.
1345 */
1346 template <typename valueT>
1347 static inline std::vector<sycl::event>
1348 dpct_memset(sycl::queue &q, pitched_data data, valueT value,
1349 sycl::range<3> size)
1350 {
1351 std::vector<sycl::event> event_list;
1352 size_t slice = data.get_pitch() * data.get_y();
1353 unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
1354 for (size_t z = 0; z < size.get(2); ++z)
1355 {
1356 unsigned char *data_ptr = data_surface;
1357 for (size_t y = 0; y < size.get(1); ++y)
1358 {
1359 event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
1360 data_ptr += data.get_pitch();
1361 }
1362 data_surface += slice;
1363 }
1364 return event_list;
1365 }
1366
1367 /**
1368 * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q.
1369 * @tparam valueT The type of the element to be set.
1370 * @param [in] q The queue in which the operation is done.
1371 * @param [in] ptr Pointer to the virtual device memory.
1372 * @param [in] pitch The pitch size by number of elements, including padding.
1373 * @param [in] val The value to be set.
1374 * @param [in] x The width of memory region by number of elements.
1375 * @param [in] y The height of memory region by number of elements.
1376 * @return An event list representing the memset operations.
1377 */
1378 template <typename valueT>
1379 static inline std::vector<sycl::event>
1380 dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
1381 size_t y)
1382 {
1383 return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
1384 sycl::range<3>(x, y, 1));
1385 }
1386
        /// Resolves \p dir to a concrete copy direction. Explicit directions
        /// pass through unchanged; `automatic` is deduced from the USM access
        /// attributes of both pointers.
        static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
                                                        const void *from_ptr,
                                                        memcpy_direction dir)
        {
            switch (dir)
            {
            case memcpy_direction::host_to_host:
            case memcpy_direction::host_to_device:
            case memcpy_direction::device_to_host:
            case memcpy_direction::device_to_device:
                return dir;
            case memcpy_direction::automatic:
            {
                // table[to_attribute][from_attribute]
                // Indexed by pointer_access_attribute (host_only, device_only,
                // host_device). host_device pointers are reachable from both
                // sides, so they adopt whichever side the other pointer needs.
                static const memcpy_direction
                    direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
                                   [static_cast<unsigned>(pointer_access_attribute::end)] =
                        {{memcpy_direction::host_to_host,
                          memcpy_direction::device_to_host,
                          memcpy_direction::host_to_host},
                         {memcpy_direction::host_to_device,
                          memcpy_direction::device_to_device,
                          memcpy_direction::device_to_device},
                         {memcpy_direction::host_to_host,
                          memcpy_direction::device_to_device,
                          memcpy_direction::device_to_device}};
                return direction_table[static_cast<unsigned>(get_pointer_attribute(
                    q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
            }
            default:
                throw std::runtime_error("dpct_memcpy: invalid direction value");
            }
        }
1420
1421 static sycl::event
1422 dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
1423 memcpy_direction direction,
1424 const std::vector<sycl::event> &dep_events = {})
1425 {
1426 if (!size)
1427 return sycl::event{};
1428 return q.memcpy(to_ptr, from_ptr, size, dep_events);
1429 GGML_UNUSED(direction);
1430 }
1431
1432 // Get actual copy range and make sure it will not exceed range.
1433 static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
1434 size_t pitch)
1435 {
1436 return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
1437 }
1438
1439 static inline size_t get_offset(sycl::id<3> id, size_t slice,
1440 size_t pitch)
1441 {
1442 return slice * id.get(2) + pitch * id.get(1) + id.get(0);
1443 }
1444
        /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
        /// and \p from_range to another specified by \p to_ptr and \p to_range.
        /// Depending on the deduced direction this lowers to one flat copy,
        /// row-by-row host copies, staged host buffers, or a device kernel.
        static inline std::vector<sycl::event>
        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
                    sycl::range<3> to_range, sycl::range<3> from_range,
                    sycl::id<3> to_id, sycl::id<3> from_id,
                    sycl::range<3> size, memcpy_direction direction,
                    const std::vector<sycl::event> &dep_events = {})
        {
            // RAII for host pointer
            class host_buffer
            {
                void *_buf;
                size_t _size;
                sycl::queue &_q;
                const std::vector<sycl::event> &_deps; // free operation depends

            public:
                host_buffer(size_t size, sycl::queue &q,
                            const std::vector<sycl::event> &deps)
                    : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
                void *get_ptr() const { return _buf; }
                size_t get_size() const { return _size; }
                // The free is queued as a host task behind _deps, so the
                // buffer stays alive until all copies using it have finished.
                ~host_buffer()
                {
                    if (_buf)
                    {
                        _q.submit([&](sycl::handler &cgh)
                                  {
                cgh.depends_on(_deps);
                cgh.host_task([buf = _buf] { std::free(buf); }); });
                    }
                }
            };
            std::vector<sycl::event> event_list;

            // Bytes per z-slice on each side.
            size_t to_slice = to_range.get(1) * to_range.get(0),
                   from_slice = from_range.get(1) * from_range.get(0);
            unsigned char *to_surface =
                (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
            const unsigned char *from_surface =
                (const unsigned char *)from_ptr +
                get_offset(from_id, from_slice, from_range.get(0));

            // Fast path: both sides dense and identically shaped — one flat
            // 1D copy suffices.
            if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
            {
                return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
                                    direction, dep_events)};
            }
            direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
            size_t size_slice = size.get(1) * size.get(0);
            switch (direction)
            {
            case host_to_host:
                // Slice-by-slice (row-by-row when pitches differ) host copies.
                for (size_t z = 0; z < size.get(2); ++z)
                {
                    unsigned char *to_ptr = to_surface;
                    const unsigned char *from_ptr = from_surface;
                    if (to_range.get(0) == from_range.get(0) &&
                        to_range.get(0) == size.get(0))
                    {
                        event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
                                                         direction, dep_events));
                    }
                    else
                    {
                        for (size_t y = 0; y < size.get(1); ++y)
                        {
                            event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
                                                             direction, dep_events));
                            to_ptr += to_range.get(0);
                            from_ptr += from_range.get(0);
                        }
                    }
                    to_surface += to_slice;
                    from_surface += from_slice;
                }
                break;
            case host_to_device:
            {
                // Stage through a host buffer shaped like the target, then
                // upload it with a single device copy.
                host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
                                event_list);
                std::vector<sycl::event> host_events;
                if (to_slice == size_slice)
                {
                    // Copy host data to a temp host buffer with the shape of target.
                    host_events =
                        dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
                                    sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
                                    host_to_host, dep_events);
                }
                else
                {
                    // Copy host data to a temp host buffer with the shape of target.
                    host_events = dpct_memcpy(
                        q, buf.get_ptr(), from_surface, to_range, from_range,
                        sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
                        // If has padding data, not sure whether it is useless. So fill temp
                        // buffer with it.
                        std::vector<sycl::event>{
                            dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
                                        device_to_host, dep_events)});
                }
                // Copy from temp host buffer to device with only one submit.
                event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
                                                 buf.get_size(), host_to_device,
                                                 host_events));
                break;
            }
            case device_to_host:
            {
                // Download once into a host staging buffer, then reshape on
                // the host.
                host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
                                event_list);
                // Copy from host temp buffer to host target with reshaping.
                event_list = dpct_memcpy(
                    q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
                    sycl::id<3>(0, 0, 0), size, host_to_host,
                    // Copy from device to temp host buffer with only one submit.
                    std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
                                                         buf.get_size(),
                                                         device_to_host, dep_events)});
                break;
            }
            case device_to_device:
                // Element-wise gather/scatter kernel on the device.
                event_list.push_back(q.submit([&](sycl::handler &cgh){
                    cgh.depends_on(dep_events);
                    cgh.parallel_for<class dpct_memcpy_3d_detail>(
                        size,
                        [=](sycl::id<3> id) {
                            to_surface[get_offset(id, to_slice, to_range.get(0))] =
                                from_surface[get_offset(id, from_slice, from_range.get(0))];
                        }); }));
                break;
            default:
                throw std::runtime_error("dpct_memcpy: invalid direction value");
            }
            return event_list;
        }
1583
1584 /// memcpy 2D/3D matrix specified by pitched_data.
1585 static inline std::vector<sycl::event>
1586 dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
1587 pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
1588 memcpy_direction direction = automatic)
1589 {
1590 return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
1591 sycl::range<3>(to.get_pitch(), to.get_y(), 1),
1592 sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
1593 size, direction);
1594 }
1595
1596 /// memcpy 2D matrix with pitch.
1597 static inline std::vector<sycl::event>
1598 dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
1599 size_t to_pitch, size_t from_pitch, size_t x, size_t y,
1600 memcpy_direction direction = automatic)
1601 {
1602 return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
1603 sycl::range<3>(from_pitch, y, 1),
1604 sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
1605 sycl::range<3>(x, y, 1), direction);
1606 }
1607
1608 namespace deprecated
1609 {
1610
        /// Deprecated allocator shim: wraps sycl::usm_allocator behind a
        /// std::allocator-style interface, bound to the dpct default queue.
        template <typename T, sycl::usm::alloc AllocKind>
        class usm_allocator
        {
        private:
            using Alloc = sycl::usm_allocator<T, AllocKind>;
            Alloc _impl;

        public:
            // Re-export every allocator_traits typedef of the wrapped allocator.
            using value_type = typename std::allocator_traits<Alloc>::value_type;
            using pointer = typename std::allocator_traits<Alloc>::pointer;
            using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
            using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
            using const_void_pointer =
                typename std::allocator_traits<Alloc>::const_void_pointer;
            using reference = typename std::allocator_traits<Alloc>::value_type &;
            using const_reference =
                const typename std::allocator_traits<Alloc>::value_type &;
            using difference_type =
                typename std::allocator_traits<Alloc>::difference_type;
            using size_type = typename std::allocator_traits<Alloc>::size_type;
            using propagate_on_container_copy_assignment = typename std::allocator_traits<
                Alloc>::propagate_on_container_copy_assignment;
            using propagate_on_container_move_assignment = typename std::allocator_traits<
                Alloc>::propagate_on_container_move_assignment;
            using propagate_on_container_swap =
                typename std::allocator_traits<Alloc>::propagate_on_container_swap;
            using is_always_equal =
                typename std::allocator_traits<Alloc>::is_always_equal;

            template <typename U>
            struct rebind
            {
                typedef usm_allocator<U, AllocKind> other;
            };

            // Binds the allocator to the process-wide default queue.
            usm_allocator() : _impl(dpct::get_default_queue()) {}
            ~usm_allocator() {}
            usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
            usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
            // NOTE(review): copy/move *assignment* operators are not declared;
            // confirm containers used with this allocator never require them.
            pointer address(reference r) { return &r; }
            const_pointer address(const_reference r) { return &r; }
            pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
            {
                return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
            }
            void deallocate(pointer p, size_type cnt)
            {
                std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
            }
            size_type max_size() const
            {
                return std::allocator_traits<Alloc>::max_size(_impl);
            }
            bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
            bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
        };
1667
1668 } // namespace deprecated
1669
1670 inline void dpct_free(void *ptr,
1671 const sycl::queue &q)
1672 {
1673 if (ptr)
1674 {
1675 sycl::free(ptr, q.get_context());
1676 }
1677 }
1678
1679 template <typename T>
1680 inline auto get_memory(const void *x)
1681 {
1682 T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
1683 return new_x;
1684 }
1685
        /// Reads a scalar through \p s, which may point at device-only USM.
        /// Device-only pointers cannot be dereferenced on the host, so the
        /// value is copied down synchronously; shared/host pointers are read
        /// directly.
        /// NOTE(review): copies sizeof(T) bytes into a DataType<T>::T2 —
        /// assumes sizeof(T) == sizeof(Ty) for every DataType mapping; confirm.
        template <typename T>
        inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
        {
            using Ty = typename DataType<T>::T2;
            Ty s_h;
            if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
                detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
                    .wait();
            else
                s_h = *reinterpret_cast<const Ty *>(s);
            return s_h;
        }
1698
1699 } // namespace detail
1700
1701 template <typename T>
1702 inline auto get_value(const T *s, sycl::queue &q)
1703 {
1704 return detail::get_value(s, q);
1705 }
1706
1707 namespace detail
1708 {
1709 template <class Ta, class Tb, class Tc, class Ts>
1710 inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
1711 int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1712 const void * beta, void * c, int ldc) {
1713 Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1714 Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1715 auto data_a = get_memory<const Ta>(a);
1716 auto data_b = get_memory<const Tb>(b);
1717 auto data_c = get_memory<Tc>(c);
1718 oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
1719 lda, data_b, ldb, beta_value, data_c, ldc);
1720 }
1721
1722 template <typename VecT, class BinaryOperation, class = void>
1723 class vectorized_binary
1724 {
1725 public:
1726 inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
1727 {
1728 VecT v4;
1729 for (size_t i = 0; i < v4.size(); ++i)
1730 {
1731 v4[i] = binary_op(a[i], b[i]);
1732 }
1733 return v4;
1734 }
1735 };
1736
        // Specialization selected (via std::void_t + std::invoke_result_t
        // SFINAE) when BinaryOperation can be invoked on whole VecT values:
        // the op is applied once and its result reinterpreted back to VecT.
        template <typename VecT, class BinaryOperation>
        class vectorized_binary<
            VecT, BinaryOperation,
            std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
        {
        public:
            inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
            {
                return binary_op(a, b).template as<VecT>();
            }
        };
1748
        /// Batched GEMM over arrays of matrix pointers using the oneMKL
        /// grouped API with a single group of \p batch_size problems.
        /// All per-group parameters are staged in \p matrix_info, which must
        /// outlive the asynchronous library call that reads it.
        template <class Ta, class Tb, class Tc, class Ts>
        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
                                    int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                    int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                    matrix_info_t<float> * matrix_info) {
            // alpha/beta may be device-resident scalars.
            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);

            matrix_info->transpose_info[0] = a_trans;
            matrix_info->transpose_info[1] = b_trans;
            matrix_info->value_info[0] = alpha_value;
            matrix_info->value_info[1] = beta_value;
            matrix_info->size_info[0] = m;
            matrix_info->size_info[1] = n;
            matrix_info->size_info[2] = k;
            matrix_info->ld_info[0] = lda;
            matrix_info->ld_info[1] = ldb;
            matrix_info->ld_info[2] = ldc;
            matrix_info->groupsize_info = batch_size;

            // group_count = 1; the staged arrays each describe that one group.
            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
        }
1777
1778 template <class Ta, class Tb, class Tc, class Ts>
1779 inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1780 int m, int n, int k, const void * alpha, const void * a, int lda,
1781 long long int stride_a, const void * b, int ldb, long long int stride_b,
1782 const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
1783 Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1784 Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1785 auto data_a = get_memory<const Ta>(a);
1786 auto data_b = get_memory<const Tb>(b);
1787 auto data_c = get_memory<Tc>(c);
1788 oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
1789 data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1790 data_c, ldc, stride_c, batch_size);
1791 }
1792
1793 } // namespace detail
1794
1795 template <typename VecT, class BinaryOperation>
1796 inline unsigned vectorized_binary(unsigned a, unsigned b,
1797 const BinaryOperation binary_op)
1798 {
1799 sycl::vec<unsigned, 1> v0{a}, v1{b};
1800 auto v2 = v0.as<VecT>();
1801 auto v3 = v1.as<VecT>();
1802 auto v4 =
1803 detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1804 v0 = v4.template as<sycl::vec<unsigned, 1>>();
1805 return v0;
1806 }
1807
1808 static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
1809 memcpy_direction direction = automatic,
1810 sycl::queue &q = dpct::get_default_queue())
1811 {
1812 detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
1813 }
1814
1815 static inline unsigned int select_device(unsigned int id)
1816 {
1817 dev_mgr::instance().select_device(id);
1818 return id;
1819 }
1820
1821 template <typename T>
1822 T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
1823 unsigned int logical_sub_group_size = 32)
1824 {
1825 unsigned int id = g.get_local_linear_id();
1826 unsigned int start_index =
1827 id / logical_sub_group_size * logical_sub_group_size;
1828 unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
1829 return sycl::select_from_group(g, x,
1830 target_offset < logical_sub_group_size
1831 ? start_index + target_offset
1832 : id);
1833 }
1834
    // Accumulator type for dp4a: unsigned 32-bit only when BOTH operand
    // types are unsigned, signed 32-bit otherwise.
    template <typename T1, typename T2>
    using dot_product_acc_t = std::conditional_t<
        std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
        uint32_t,
        int32_t>;
1840
    // Splits a 32-bit value into its four bytes and widens each back to T:
    // the bytes are reinterpreted as int8 when T is signed (sign-extension
    // via convert) or uint8 when unsigned (zero-extension).
    template <typename T>
    sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
        return sycl::vec<T, 1>(val)
            .template as<sycl::vec<
                std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
                4>>()
            .template convert<T>();
    }
1849
1850 template <typename T1, typename T2, typename T3>
1851 inline auto dp4a(T1 a, T2 b, T3 c) {
1852 dot_product_acc_t<T1, T2> res = c;
1853 auto va = extract_and_sign_or_zero_extend4(a);
1854 auto vb = extract_and_sign_or_zero_extend4(b);
1855 res += va[0] * vb[0];
1856 res += va[1] * vb[1];
1857 res += va[2] * vb[2];
1858 res += va[3] * vb[3];
1859 return res;
1860 }
1861
1862 struct sub_sat
1863 {
1864 template <typename T>
1865 auto operator()(const T x, const T y) const
1866 {
1867 return sycl::sub_sat(x, y);
1868 }
1869 };
1870
1871 template <typename S, typename T>
1872 inline T vectorized_min(T a, T b)
1873 {
1874 sycl::vec<T, 1> v0{a}, v1{b};
1875 auto v2 = v0.template as<S>();
1876 auto v3 = v1.template as<S>();
1877 auto v4 = sycl::min(v2, v3);
1878 v0 = v4.template as<sycl::vec<T, 1>>();
1879 return v0;
1880 }
1881
    // pow overloads: float/double bases with integer exponents route to
    // sycl::pown; other combinations go to sycl::pow, with non-floating-point
    // bases promoted to double.
    inline float pow(const float a, const int b) { return sycl::pown(a, b); }
    inline double pow(const double a, const int b) { return sycl::pown(a, b); }
    inline float pow(const float a, const float b) { return sycl::pow(a, b); }
    inline double pow(const double a, const double b) { return sycl::pow(a, b); }
    template <typename T, typename U>
    inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
    pow(const T a, const U b)
    {
        return sycl::pow(a, static_cast<T>(b));
    }
    template <typename T, typename U>
    inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
    pow(const T a, const U b)
    {
        return sycl::pow(static_cast<double>(a), static_cast<double>(b));
    }
1898
    // min function overloads.
    // For floating-point types, `float` or `double` arguments are acceptable.
    // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
    // `std::int64_t` type arguments are acceptable.
    // Mixed-sign integer overloads cast the signed argument to unsigned
    // before comparing.
    inline double min(const double a, const float b)
    {
        return sycl::fmin(a, static_cast<double>(b));
    }
    inline double min(const float a, const double b)
    {
        return sycl::fmin(static_cast<double>(a), b);
    }
    inline float min(const float a, const float b) { return sycl::fmin(a, b); }
    inline double min(const double a, const double b) { return sycl::fmin(a, b); }
    inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
    {
        return sycl::min(a, static_cast<std::uint32_t>(b));
    }
    inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
    {
        return sycl::min(static_cast<std::uint32_t>(a), b);
    }
    inline std::int32_t min(const std::int32_t a, const std::int32_t b)
    {
        return sycl::min(a, b);
    }
    inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
    {
        return sycl::min(a, b);
    }
    inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
    {
        return sycl::min(a, static_cast<std::uint64_t>(b));
    }
    inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
    {
        return sycl::min(static_cast<std::uint64_t>(a), b);
    }
    inline std::int64_t min(const std::int64_t a, const std::int64_t b)
    {
        return sycl::min(a, b);
    }
    inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
    {
        return sycl::min(a, b);
    }
    inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
    {
        return sycl::min(a, static_cast<std::uint64_t>(b));
    }
    inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
    {
        return sycl::min(static_cast<std::uint64_t>(a), b);
    }
    inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
    {
        return sycl::min(a, static_cast<std::uint64_t>(b));
    }
    inline std::uint32_t min(const std::uint32_t a, const std::uint64_t b)
    {
        return sycl::min(static_cast<std::uint64_t>(a), b);
    }
    // max function overloads.
    // For floating-point types, `float` or `double` arguments are acceptable.
    // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
    // `std::int64_t` type arguments are acceptable.
    // Mixed-sign integer overloads cast the signed argument to unsigned
    // before comparing.
    inline double max(const double a, const float b)
    {
        return sycl::fmax(a, static_cast<double>(b));
    }
    inline double max(const float a, const double b)
    {
        return sycl::fmax(static_cast<double>(a), b);
    }
    inline float max(const float a, const float b) { return sycl::fmax(a, b); }
    inline double max(const double a, const double b) { return sycl::fmax(a, b); }
    inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
    {
        return sycl::max(a, static_cast<std::uint32_t>(b));
    }
    inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
    {
        return sycl::max(static_cast<std::uint32_t>(a), b);
    }
    inline std::int32_t max(const std::int32_t a, const std::int32_t b)
    {
        return sycl::max(a, b);
    }
    inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
    {
        return sycl::max(a, b);
    }
    inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
    {
        return sycl::max(a, static_cast<std::uint64_t>(b));
    }
    inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
    {
        return sycl::max(static_cast<std::uint64_t>(a), b);
    }
    inline std::int64_t max(const std::int64_t a, const std::int64_t b)
    {
        return sycl::max(a, b);
    }
    inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
    {
        return sycl::max(a, b);
    }
    inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
    {
        return sycl::max(a, static_cast<std::uint64_t>(b));
    }
    inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
    {
        return sycl::max(static_cast<std::uint64_t>(a), b);
    }
2011 inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
2012 {
2013 return sycl::max(a, static_cast<std::uint64_t>(b));
2014 }
2015 inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
2016 {
2017 return sycl::max(static_cast<std::uint64_t>(a), b);
2018 }
2019
/// Throw std::runtime_error if device \p dev lacks any aspect listed in
/// \p props; return normally when every requested aspect is present.
inline void
has_capability_or_fail(const sycl::device &dev,
                       const std::initializer_list<sycl::aspect> &props)
{
    for (const auto &it : props)
    {
        // Aspect supported: move on to the next requested one.
        if (dev.has(it))
            continue;
        switch (it)
        {
        case sycl::aspect::fp64:
            throw std::runtime_error("'double' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
            break;
        case sycl::aspect::fp16:
            throw std::runtime_error("'half' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
            break;
        default:
// The __SYCL_ASPECT* macros expand each entry of the SYCL aspects.def /
// aspects_deprecated.def lists into a `case` returning the aspect's name as a
// string, so the generic error message below can name any aspect.
#define __SYCL_ASPECT(ASPECT, ID)                                              \
    case sycl::aspect::ASPECT:                                                 \
        return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
            auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string
            {
                switch (AspectNum)
                {
#include <sycl/info/aspects.def>
#include <sycl/info/aspects_deprecated.def>
                default:
                    return "unknown aspect";
                }
            };
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
#undef __SYCL_ASPECT_DEPRECATED
#undef __SYCL_ASPECT
            throw std::runtime_error(
                "'" + getAspectNameStr(it) + "' is not supported in '" +
                dev.get_info<sycl::info::device::name>() + "' device");
        }
        // Unreachable: every branch of the switch throws. Kept for form.
        break;
    }
}
2066
2067 static inline unsigned int get_current_device_id()
2068 {
2069 return dev_mgr::instance().current_device_id();
2070 }
2071
2072 static inline device_ext &get_current_device()
2073 {
2074 return dev_mgr::instance().current_device();
2075 }
2076
2077 static inline device_ext &get_device(unsigned int id)
2078 {
2079 return dev_mgr::instance().get_device(id);
2080 }
2081
2082 static inline sycl::queue &get_in_order_queue()
2083 {
2084 return dev_mgr::instance().current_device().in_order_queue();
2085 }
2086
2087 static sycl::event
2088 dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
2089 memcpy_direction direction,
2090 const std::vector<sycl::event> &dep_events = {})
2091 {
2092 if (!size)
2093 return sycl::event{};
2094 return q.memcpy(to_ptr, from_ptr, size, dep_events);
2095 GGML_UNUSED(direction);
2096 }
2097
2098 // Get actual copy range and make sure it will not exceed range.
2099 static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
2100 size_t pitch)
2101 {
2102 return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
2103 }
2104
2105 static inline size_t get_offset(sycl::id<3> id, size_t slice,
2106 size_t pitch)
2107 {
2108 return slice * id.get(2) + pitch * id.get(1) + id.get(0);
2109 }
2110
/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
/// and \p from_range to another specified by \p to_ptr and \p to_range.
/// Returns the list of events for the submitted copies; ordering w.r.t.
/// \p dep_events is preserved on every submission path.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
            sycl::range<3> to_range, sycl::range<3> from_range,
            sycl::id<3> to_id, sycl::id<3> from_id,
            sycl::range<3> size, memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    // RAII for host pointer. The destructor does NOT free immediately: it
    // submits a host_task that frees the buffer only after the events in
    // `_deps` (captured by reference — here, the enclosing event_list) have
    // completed, so the staging memory outlives the async copies using it.
    class host_buffer
    {
        void *_buf;
        size_t _size;
        sycl::queue &_q;
        const std::vector<sycl::event> &_deps; // free operation depends

    public:
        host_buffer(size_t size, sycl::queue &q,
                    const std::vector<sycl::event> &deps)
            : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
        void *get_ptr() const { return _buf; }
        size_t get_size() const { return _size; }
        ~host_buffer()
        {
            if (_buf)
            {
                _q.submit([&](sycl::handler &cgh)
                          {
                cgh.depends_on(_deps);
                cgh.host_task([buf = _buf] { std::free(buf); }); });
            }
        }
    };
    std::vector<sycl::event> event_list;

    // Bytes per z-slice on each side, and the start of the copied region
    // after applying the 3-D offsets.
    size_t to_slice = to_range.get(1) * to_range.get(0),
           from_slice = from_range.get(1) * from_range.get(0);
    unsigned char *to_surface =
        (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
    const unsigned char *from_surface =
        (const unsigned char *)from_ptr +
        get_offset(from_id, from_slice, from_range.get(0));

    // Fast path: both sides are densely packed with identical layout, so the
    // whole 3-D region is one contiguous 1-D copy.
    if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
    {
        return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
                            direction, dep_events)};
    }
    direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
    size_t size_slice = size.get(1) * size.get(0);
    switch (direction)
    {
    case host_to_host:
        // Slice by slice; within a slice either one packed copy or a copy per
        // row when the pitches differ.
        for (size_t z = 0; z < size.get(2); ++z)
        {
            unsigned char *to_ptr = to_surface;
            const unsigned char *from_ptr = from_surface;
            if (to_range.get(0) == from_range.get(0) &&
                to_range.get(0) == size.get(0))
            {
                event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
                                                 direction, dep_events));
            }
            else
            {
                for (size_t y = 0; y < size.get(1); ++y)
                {
                    event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
                                                     direction, dep_events));
                    to_ptr += to_range.get(0);
                    from_ptr += from_range.get(0);
                }
            }
            to_surface += to_slice;
            from_surface += from_slice;
        }
        break;
    case host_to_device:
    {
        // Stage through a host buffer shaped like the destination so the
        // device transfer is a single contiguous memcpy.
        host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
                        event_list);
        std::vector<sycl::event> host_events;
        if (to_slice == size_slice)
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events =
                dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
                            sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
                            host_to_host, dep_events);
        }
        else
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events = dpct_memcpy(
                q, buf.get_ptr(), from_surface, to_range, from_range,
                sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
                // If has padding data, not sure whether it is useless. So fill temp
                // buffer with it.
                std::vector<sycl::event>{
                    dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
                                device_to_host, dep_events)});
        }
        // Copy from temp host buffer to device with only one submit.
        event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
                                         buf.get_size(), host_to_device,
                                         host_events));
        break;
    }
    case device_to_host:
    {
        // Mirror of host_to_device: one contiguous device read into a staging
        // buffer, then reshape on the host.
        host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
                        event_list);
        // Copy from host temp buffer to host target with reshaping.
        event_list = dpct_memcpy(
            q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
            sycl::id<3>(0, 0, 0), size, host_to_host,
            // Copy from device to temp host buffer with only one submit.
            std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
                                                 buf.get_size(),
                                                 device_to_host, dep_events)});
        break;
    }
    case device_to_device:
        // Element-wise kernel copy: handles arbitrary pitches on both sides.
        event_list.push_back(q.submit([&](sycl::handler &cgh)
                                      {
        cgh.depends_on(dep_events);
        cgh.parallel_for<class dpct_memcpy_3d_detail>(
            size,
            [=](sycl::id<3> id) {
                to_surface[get_offset(id, to_slice, to_range.get(0))] =
                    from_surface[get_offset(id, from_slice, from_range.get(0))];
            }); }));
        break;
    default:
        throw std::runtime_error("dpct_memcpy: invalid direction value");
    }
    return event_list;
}
2250
2251 /// memcpy 2D/3D matrix specified by pitched_data.
2252 static inline std::vector<sycl::event>
2253 dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
2254 pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
2255 memcpy_direction direction = automatic)
2256 {
2257 return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
2258 sycl::range<3>(to.get_pitch(), to.get_y(), 1),
2259 sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
2260 size, direction);
2261 }
2262
2263 /// memcpy 2D matrix with pitch.
2264 static inline std::vector<sycl::event>
2265 dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
2266 size_t to_pitch, size_t from_pitch, size_t x, size_t y,
2267 memcpy_direction direction = automatic)
2268 {
2269 return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
2270 sycl::range<3>(from_pitch, y, 1),
2271 sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
2272 sycl::range<3>(x, y, 1), direction);
2273 }
2274
/// Computes a matrix-matrix product with general matrices:
/// C = alpha * op(A) * op(B) + beta * C.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] scaling_type Data type of the scaling factors.
/// \throws std::runtime_error if the (a, b, c, scaling) type combination has
/// no dispatch case below.
inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
                 int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
                 library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
                 library_data_t scaling_type) {
    // When C is complex, a real scaling type is promoted to the matching
    // complex type so the dispatch key below matches a complex case.
    if (scaling_type == library_data_t::real_float &&
        c_type == library_data_t::complex_float)
    {
        scaling_type = library_data_t::complex_float;
    }
    else if (scaling_type == library_data_t::real_double &&
             c_type == library_data_t::complex_double)
    {
        scaling_type = library_data_t::complex_double;
    }

    // Pack the four data types into one integer key so a single switch can
    // select the right gemm_impl instantiation.
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<float, float, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_impl<double, double, double, double>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_float, library_data_t::complex_float,
        library_data_t::complex_float, library_data_t::complex_float):
    {
        detail::gemm_impl<std::complex<float>, std::complex<float>,
                          std::complex<float>, std::complex<float>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_double, library_data_t::complex_double,
        library_data_t::complex_double, library_data_t::complex_double):
    {
        detail::gemm_impl<std::complex<double>, std::complex<double>,
                          std::complex<double>, std::complex<double>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_impl<sycl::half, sycl::half, sycl::half,
                          sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,
                                      lda, b, ldb, beta, c, ldc);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        // Scaling factors arrive as float but the all-half impl consumes
        // sycl::half; read them back via the queue and convert.
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_impl<sycl::half, sycl::half, sycl::half,
                          sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,
                                      a, lda, b, ldb, &beta_half, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        // int32 scaling factors are read back and forwarded as float to the
        // <int8, int8, int32, float> instantiation.
        float alpha_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
        float beta_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
        detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
            q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
        break;
    }
#endif // __INTEL_MKL__
    default:
        throw std::runtime_error("the combination of data type is unsupported");
    }
} // gemm()
2402
/// Computes a batch of matrix-matrix product with general matrices.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors.
/// \param [in] matrix_info Forwarded to detail::gemm_batch_impl together with
/// the per-matrix pointer arrays; presumably holds the per-batch argument
/// block consumed by the implementation — confirm against gemm_batch_impl.
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                       int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                       const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                       library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
                       matrix_info_t<float> * matrix_info) {
    // Pack the four data types into one switch key (same dispatch scheme as
    // gemm()).
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
                                                            beta, c, ldc, batch_size, matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
                                                                beta, c, ldc, batch_size, matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
        break;
    }
#endif
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        // int32 scaling factors are read back via the queue and forwarded as
        // float to the <int8, int8, int32, float> instantiation.
        float alpha_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
        float beta_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
        detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
            q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
            matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        // Scaling factors arrive as float but the all-half impl consumes
        // sycl::half; read them back and convert.
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
            q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
        break;
    }
    default:
        throw std::runtime_error("the combination of data type is unsupported");
    }
}
2521
/// Computes a batch of matrix-matrix product with general matrices.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] stride_a Stride between the different A matrices.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] stride_b Stride between the different B matrices.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] stride_c Stride between the different C matrices.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                       int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
                       long long int stride_a, const void * b, library_data_t b_type, int ldb,
                       long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
                       long long int stride_c, int batch_size, library_data_t scaling_type) {
    // Promote a real scaling type to the matching complex type when C is
    // complex, mirroring gemm().
    if (scaling_type == library_data_t::real_float &&
        c_type == library_data_t::complex_float)
    {
        scaling_type = library_data_t::complex_float;
    }
    else if (scaling_type == library_data_t::real_double &&
             c_type == library_data_t::complex_double)
    {
        scaling_type = library_data_t::complex_double;
    }

    // Pack the four data types into one switch key (same dispatch scheme as
    // gemm()).
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<float, float, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_batch_impl<double, double, double, double>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_float, library_data_t::complex_float,
        library_data_t::complex_float, library_data_t::complex_float):
    {
        detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
                                std::complex<float>, std::complex<float>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_double, library_data_t::complex_double,
        library_data_t::complex_double, library_data_t::complex_double):
    {
        detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
                                std::complex<double>, std::complex<double>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
                                sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
                                            a, lda, stride_a, b, ldb, stride_b,
                                            beta, c, ldc, stride_c, batch_size);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
            batch_size);
        break;
    }
#endif
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        // NOTE: unlike the pointer-array overload above, the int32 scaling
        // factors are forwarded as int32 here (no float conversion).
        detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
                                std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
                                              a, lda, stride_a, b, ldb, stride_b,
                                              beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        // Scaling factors arrive as float but the all-half impl consumes
        // sycl::half; read them back and convert.
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
            q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
            &beta_half, c, ldc, stride_c, batch_size);
        break;
    }
    default:
        throw std::runtime_error("the combination of data type is unsupported");
    }
}
2680
2681 static inline void
2682 async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
2683 size_t from_pitch, size_t x, size_t y,
2684 memcpy_direction direction = automatic,
2685 sycl::queue &q = get_default_queue())
2686 {
2687 detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
2688 direction);
2689 }
2690
// Tagged integer error types: the distinct tag structs make err0 and err1
// different types, so they cannot be mixed up at compile time.
using err0 = detail::generic_error_type<struct err0_tag, int>;
using err1 = detail::generic_error_type<struct err1_tag, int>;
2693
2694 static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {
2695 detail::dpct_free(ptr, q);
2696 }
2697
/// dpct accessor used as device function parameter.
template <class T, memory_region Memory, size_t Dimension> class accessor;
/// 3-D specialization: wraps a raw pointer plus a 3-D extent; operator[]
/// peels off the outermost dimension and yields a 2-D accessor over one plane.
template <class T, memory_region Memory> class accessor<T, Memory, 3> {
public:
  using memory_t = detail::memory_traits<Memory, T>;
  using element_t = typename memory_t::element_t;
  using pointer_t = typename memory_t::pointer_t;
  using accessor_t = typename memory_t::template accessor_t<3>;
  // Wrap an existing pointer with its 3-D extent.
  accessor(pointer_t data, const sycl::range<3> &in_range)
      : _data(data), _range(in_range) {}
  // Construct from a SYCL accessor (disabled by enable_if for local memory).
  template <memory_region M = Memory>
  accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
      : accessor(acc, acc.get_range()) {}
  accessor(const accessor_t &acc, const sycl::range<3> &in_range)
      : accessor(acc.get_pointer(), in_range) {}
  // Select plane `index`: advance by index * (dim1 * dim2) elements and keep
  // the two inner dimensions as the sub-accessor's extent.
  accessor<T, Memory, 2> operator[](size_t index) const {
    sycl::range<2> sub(_range.get(1), _range.get(2));
    return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
  }

  // Raw pointer to the first element.
  pointer_t get_ptr() const { return _data; }

private:
  pointer_t _data;
  sycl::range<3> _range;
};
/// 2-D specialization: operator[] yields a raw pointer to the selected row.
template <class T, memory_region Memory> class accessor<T, Memory, 2> {
public:
  using memory_t = detail::memory_traits<Memory, T>;
  using element_t = typename memory_t::element_t;
  using pointer_t = typename memory_t::pointer_t;
  using accessor_t = typename memory_t::template accessor_t<2>;
  // Wrap an existing pointer with its 2-D extent.
  accessor(pointer_t data, const sycl::range<2> &in_range)
      : _data(data), _range(in_range) {}
  // Construct from a SYCL accessor (disabled by enable_if for local memory).
  template <memory_region M = Memory>
  accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
      : accessor(acc, acc.get_range()) {}
  accessor(const accessor_t &acc, const sycl::range<2> &in_range)
      : accessor(acc.get_pointer(), in_range) {}

  // Row access: rows are _range.get(1) elements apart.
  pointer_t operator[](size_t index) const {
    return _data + _range.get(1) * index;
  }

  // Raw pointer to the first element.
  pointer_t get_ptr() const { return _data; }

private:
  pointer_t _data;
  sycl::range<2> _range;
};
2748
2749 namespace detail {
2750 /// Device variable with address space of shared, global or constant.
2751 template <class T, memory_region Memory, size_t Dimension> class device_memory {
2752 public:
2753 using accessor_t =
2754 typename detail::memory_traits<Memory,
2755 T>::template accessor_t<Dimension>;
2756 using value_t = typename detail::memory_traits<Memory, T>::value_t;
2757 using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
2758
2759 device_memory() : device_memory(sycl::range<Dimension>(1)) {}
2760
2761 /// Constructor of 1-D array with initializer list
2762 device_memory(const sycl::range<Dimension> &in_range,
2763 std::initializer_list<value_t> &&init_list)
2764 : device_memory(in_range) {
2765 assert(init_list.size() <= in_range.size());
2766 _host_ptr = (value_t *)std::malloc(_size);
2767 std::memset(_host_ptr, 0, _size);
2768 std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
2769 }
2770
2771 /// Constructor of 2-D array with initializer list
2772 template <size_t D = Dimension>
2773 device_memory(
2774 const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
2775 std::initializer_list<std::initializer_list<value_t>> &&init_list)
2776 : device_memory(in_range) {
2777 assert(init_list.size() <= in_range[0]);
2778 _host_ptr = (value_t *)std::malloc(_size);
2779 std::memset(_host_ptr, 0, _size);
2780 auto tmp_data = _host_ptr;
2781 for (auto sub_list : init_list) {
2782 assert(sub_list.size() <= in_range[1]);
2783 std::memcpy(tmp_data, sub_list.begin(),
2784 sub_list.size() * sizeof(T));
2785 tmp_data += in_range[1];
2786 }
2787 }
2788
2789 /// Constructor with range
2790 device_memory(const sycl::range<Dimension> &range_in)
2791 : _size(range_in.size() * sizeof(T)), _range(range_in),
2792 _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
2793 static_assert(
2794 (Memory == global) || (Memory == constant) || (Memory == shared),
2795 "device memory region should be global, constant or shared");
2796 // Make sure that singleton class mem_mgr and dev_mgr will destruct
2797 // later than this.
2798 detail::mem_mgr::instance();
2799 dev_mgr::instance();
2800 }
2801
2802 /// Constructor with range
2803 template <class... Args>
2804 device_memory(Args... Arguments)
2805 : device_memory(sycl::range<Dimension>(Arguments...)) {}
2806
2807 ~device_memory() {
2808 if (_device_ptr && !_reference)
2809 dpct::dpct_free(_device_ptr);
2810 if (_host_ptr)
2811 std::free(_host_ptr);
2812 }
2813
2814 /// Allocate memory with default queue, and init memory if has initial
2815 /// value.
2816 void init() { init(dpct::get_default_queue()); }
2817 /// Allocate memory with specified queue, and init memory if has initial
2818 /// value.
2819 void init(sycl::queue &q) {
2820 if (_device_ptr)
2821 return;
2822 if (!_size)
2823 return;
2824 allocate_device(q);
2825 if (_host_ptr)
2826 detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,
2827 host_to_device);
2828 }
2829
2830 /// The variable is assigned to a device pointer.
2831 void assign(value_t *src, size_t size) {
2832 this->~device_memory();
2833 new (this) device_memory(src, size);
2834 }
2835
2836 /// Get memory pointer of the memory object, which is virtual pointer when
2837 /// usm is not used, and device pointer when usm is used.
2838 value_t *get_ptr() { return get_ptr(get_default_queue()); }
2839 /// Get memory pointer of the memory object, which is virtual pointer when
2840 /// usm is not used, and device pointer when usm is used.
2841 value_t *get_ptr(sycl::queue &q) {
2842 init(q);
2843 return _device_ptr;
2844 }
2845
2846 /// Get the device memory object size in bytes.
2847 size_t get_size() { return _size; }
2848
2849 template <size_t D = Dimension>
2850 typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
2851 init();
2852 return _device_ptr[index];
2853 }
2854
2855 /// Get dpct::accessor with dimension info for the device memory object
2856 /// when usm is used and dimension is greater than 1.
2857 template <size_t D = Dimension>
2858 typename std::enable_if<D != 1, dpct_accessor_t>::type
2859 get_access([[maybe_unused]] sycl::handler &cgh) {
2860 return dpct_accessor_t((T *)_device_ptr, _range);
2861 }
2862
2863 private:
2864 device_memory(value_t *memory_ptr, size_t size)
2865 : _size(size), _range(size / sizeof(T)), _reference(true),
2866 _device_ptr(memory_ptr) {}
2867
2868 void allocate_device(sycl::queue &q) {
2869 #ifndef DPCT_USM_LEVEL_NONE
2870 if (Memory == shared) {
2871 _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
2872 q.get_context());
2873 return;
2874 }
2875 #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
2876 if (Memory == constant) {
2877 _device_ptr = (value_t *)sycl::malloc_device(
2878 _size, q.get_device(), q.get_context(),
2879 sycl::ext::oneapi::property::usm::device_read_only());
2880 return;
2881 }
2882 #endif
2883 #endif
2884 _device_ptr = (value_t *)detail::dpct_malloc(_size, q);
2885 }
2886
2887 size_t _size;
2888 sycl::range<Dimension> _range;
2889 bool _reference;
2890 value_t *_host_ptr;
2891 value_t *_device_ptr;
2892 };
2893 template <class T, memory_region Memory>
2894 class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
2895 public:
2896 using base = device_memory<T, Memory, 1>;
2897 using value_t = typename base::value_t;
2898 using accessor_t =
2899 typename detail::memory_traits<Memory, T>::template accessor_t<0>;
2900
2901 /// Constructor with initial value.
2902 device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
2903
2904 /// Default constructor
2905 device_memory() : base(1) {}
2906 };
2907 } // namespace detail
2908
/// Public aliases fixing the memory_region of detail::device_memory at the
/// type level. `global_memory` is a plain device allocation.
template <class T, size_t Dimension>
using global_memory = detail::device_memory<T, global, Dimension>;
/// Constant memory: allocated with the device_read_only USM property when
/// the extension is available (see device_memory::allocate_device).
template <class T, size_t Dimension>
using constant_memory = detail::device_memory<T, constant, Dimension>;
/// Shared memory: allocated with sycl::malloc_shared (host+device USM).
template <class T, size_t Dimension>
using shared_memory = detail::device_memory<T, shared, Dimension>;
2915
2916
2917 template <typename T,
2918 sycl::access::address_space addressSpace =
2919 sycl::access::address_space::global_space,
2920 sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2921 sycl::memory_scope memoryScope = sycl::memory_scope::device>
2922 inline T atomic_fetch_add(T *addr, T operand) {
2923 auto atm =
2924 sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
2925 return atm.fetch_add(operand);
2926 }
2927
2928 template <sycl::access::address_space addressSpace =
2929 sycl::access::address_space::global_space,
2930 sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2931 sycl::memory_scope memoryScope = sycl::memory_scope::device,
2932 typename T1, typename T2>
2933 inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
2934 auto atm =
2935 sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
2936 return atm.fetch_add(operand);
2937 }
2938
2939 template <typename T, sycl::access::address_space addressSpace =
2940 sycl::access::address_space::global_space>
2941 inline T atomic_fetch_add(T *addr, T operand,
2942 sycl::memory_order memoryOrder) {
2943 switch (memoryOrder) {
2944 case sycl::memory_order::relaxed:
2945 return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
2946 sycl::memory_scope::device>(addr, operand);
2947 case sycl::memory_order::acq_rel:
2948 return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
2949 sycl::memory_scope::device>(addr, operand);
2950 case sycl::memory_order::seq_cst:
2951 return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
2952 sycl::memory_scope::device>(addr, operand);
2953 default:
2954 assert(false && "Invalid memory_order for atomics. Valid memory_order for "
2955 "atomics are: sycl::memory_order::relaxed, "
2956 "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
2957 }
2958 }
2959
2960 template <sycl::access::address_space addressSpace =
2961 sycl::access::address_space::global_space,
2962 typename T1, typename T2>
2963 inline T1 atomic_fetch_add(T1 *addr, T2 operand,
2964 sycl::memory_order memoryOrder) {
2965 atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2966 }
2967
/// Byte-level permute (analogue of CUDA's __byte_perm): treats `b:a` as one
/// 64-bit value (b = high 32 bits, a = low 32 bits) and assembles a 32-bit
/// result whose i-th byte is the source byte selected by the i-th nibble of
/// `s`. Only bits [2:0] of each nibble are used, addressing the 8 source
/// bytes; bit 3 (the sign/replicate flag of the CUDA selector) is ignored.
inline unsigned int byte_level_permute(
    unsigned int a, unsigned int b, unsigned int s) {
  // Hoist the 64-bit concatenation, which the original recomputed for each
  // of the four byte selections.
  const std::uint64_t combined = ((std::uint64_t)b << 32) | a;
  unsigned int ret = 0;
  for (int i = 0; i < 4; ++i) {
    const unsigned int sel = (s >> (4 * i)) & 0x7;
    ret |= (unsigned int)((combined >> (sel * 8)) & 0xff) << (8 * i);
  }
  return ret;
}
2980
2981 inline uint32_t byte_level_permute_custom(
2982 uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
2983 constexpr uint16_t lookup[6][4] = {
2984 {0x3210, 0x4321, 0x5432, 0x6543}, // Forward 4-byte extract
2985 {0x5670, 0x6701, 0x7012, 0x0123}, // Backward 4-byte extract
2986 {0x0000, 0x1111, 0x2222, 0x3333}, // Replicate 8-bit values
2987 {0x3210, 0x3211, 0x3222, 0x3333}, // Edge clamp left
2988 {0x0000, 0x1110, 0x2210, 0x3210}, // Edge clamp right
2989 {0x1010, 0x3232, 0x1010, 0x3232} // Replicate 16-bit values
2990 };
2991
2992 if (mode >= 1 && mode <= 6) {
2993 return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
2994 } else if (!mode) {
2995 return byte_level_permute(low32, high32, sel);
2996 }
2997 return 0;
2998 }
2999
3000} // COPY from DPCT head files
3001
3002#endif // GGML_SYCL_DPCT_HELPER_HPP