1//
   2// MIT license
   3// Copyright (C) 2024 Intel Corporation
   4// SPDX-License-Identifier: MIT
   5//
   6
   7//
   8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
   9// See https://llvm.org/LICENSE.txt for license information.
  10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11//
  12
  13#ifndef GGML_SYCL_DPCT_HELPER_HPP
  14#define GGML_SYCL_DPCT_HELPER_HPP
  15
  16#include <sycl/sycl.hpp>
  17#include <sycl/half_type.hpp>
  18#include <oneapi/mkl.hpp>
  19
  20#include <map>
  21
  22#include "ggml.h"
  23
  24#if defined(__linux__)
  25#include <sys/mman.h>
  26#elif defined(_WIN64)
  27#ifndef NOMINMAX
  28#define NOMINMAX
  29#endif
  30#include <windows.h>
  31#else
  32#error "Only support Windows and Linux."
  33#endif
  34
  35#if defined(__linux__)
  36#include <unistd.h>
  37#include <sys/syscall.h>
  38#endif
  39#if defined(_WIN64)
  40#ifndef NOMINMAX
  41#define NOMINMAX
  42#endif
  43#include <windows.h>
  44#endif
  45
// Compatibility level constant exposed to migrated code.
// NOTE(review): presumably stands in for a CUDA compute-capability value
// (e.g. __CUDA_ARCH__) — confirm against the users of this macro.
#define DPCT_COMPATIBILITY_TEMP (900)

// Alignment and forced-inlining attributes, spelled for MSVC or for
// GCC/Clang-compatible compilers.
#if defined(_MSC_VER)
#define __dpct_align__(n) __declspec(align(n))
#define __dpct_inline__ __forceinline
#else
#define __dpct_align__(n) __attribute__((aligned(n)))
#define __dpct_inline__ __inline__ __attribute__((always_inline))
#endif

// Attribute that forbids inlining, again in MSVC and GCC/Clang spellings.
#if defined(_MSC_VER)
#define __dpct_noinline__ __declspec(noinline)
#else
#define __dpct_noinline__ __attribute__((noinline))
#endif
  61
  62inline std::string get_device_type_name(const sycl::device &Device) {
  63    auto DeviceType = Device.get_info<sycl::info::device::device_type>();
  64    switch (DeviceType) {
  65    case sycl::info::device_type::cpu:
  66        return "cpu";
  67    case sycl::info::device_type::gpu:
  68        return "gpu";
  69    case sycl::info::device_type::host:
  70        return "host";
  71    case sycl::info::device_type::accelerator:
  72        return "acc";
  73    default:
  74        return "unknown";
  75    }
  76}
  77
  78inline std::string get_device_backend_and_type(const sycl::device &device) {
  79    std::stringstream device_type;
  80    sycl::backend backend = device.get_backend();
  81    device_type <<  backend << ":" << get_device_type_name(device);
  82    return device_type.str();
  83}
  84
/// Parameter bundle for a matrix operation, templated on the scalar type Ts.
/// NOTE(review): the per-field meanings below are inferred from the names and
/// array sizes only — confirm against the code that fills this struct.
template <typename Ts> struct matrix_info_t {
    oneapi::mkl::transpose transpose_info[2]; // presumably transpose modes of the two input matrices
    Ts                     value_info[2];     // presumably the two scalar factors (alpha, beta)
    std::int64_t           size_info[3];      // presumably the m, n, k extents
    std::int64_t           ld_info[3];        // presumably the leading dimensions of the three matrices
    std::int64_t           groupsize_info;    // presumably the number of matrices in the group
};
  92
  93namespace dpct
  94{
  95    typedef sycl::queue *queue_ptr;
  96    typedef sycl::event *event_ptr;
  97    typedef char *device_ptr;
  98    typedef uint8_t byte_t;
  99    typedef sycl::buffer<byte_t> buffer_t;
 100
 101    /// SYCL default exception handler
 102    inline auto exception_handler = [](sycl::exception_list exceptions)
 103    {
 104        for (std::exception_ptr const &e : exceptions)
 105        {
 106            try
 107            {
 108                std::rethrow_exception(e);
 109            }
 110            catch (sycl::exception const &e)
 111            {
 112                std::cerr << "Caught asynchronous SYCL exception:" << std::endl
 113                          << e.what() << std::endl
 114                          << "Exception caught at file:" << __FILE__
 115                          << ", line:" << __LINE__ << std::endl;
 116            }
 117        }
 118    };
 119
    /// Status codes: `success` is 0 and `default_error` (999) is the only
    /// failure value defined in this enum.
    enum error_code
    {
        success = 0,
        default_error = 999
    };
 125
    /// Source/destination combinations for a memory copy.
    /// NOTE(review): `automatic` presumably asks the copy routine to deduce
    /// the direction from the pointers — confirm in the memcpy helpers.
    enum memcpy_direction
    {
        host_to_host,
        host_to_device,
        device_to_host,
        device_to_device,
        automatic
    };
 134
    /// Kinds of memory a variable can be placed in; numbering starts at 0
    /// with `global`.
    enum memory_region
    {
        global = 0, // device global memory
        constant,   // device constant memory
        local,      // device local memory
        shared,     // memory which can be accessed by host and device
    };
 142
    /// Element-type tags, stored as `unsigned char`.
    /// Enumerators are numbered consecutively from `real_float = 0`, so the
    /// trailing `library_data_t_size` equals the number of tags before it.
    enum class library_data_t : unsigned char
    {
        real_float = 0,
        complex_float,
        real_double,
        complex_double,
        real_half,
        complex_half,
        real_bfloat16,
        complex_bfloat16,
        real_int4,
        complex_int4,
        real_uint4,
        complex_uint4,
        real_int8,
        complex_int8,
        real_uint8,
        complex_uint8,
        real_int16,
        complex_int16,
        real_uint16,
        complex_uint16,
        real_int32,
        complex_int32,
        real_uint32,
        complex_uint32,
        real_int64,
        complex_int64,
        real_uint64,
        complex_uint64,
        real_int8_4,
        real_int8_32,
        real_uint8_4,
        library_data_t_size // count sentinel, not a data type
    };
 178
    /// Type mapping exposed through the nested alias `T2`.
    /// Primary template: `T2` is `T` itself.
    template <typename T>
    struct DataType
    {
        using T2 = T;
    };
    /// Specialization for two-element SYCL vectors: `sycl::vec<T, 2>` maps to
    /// `std::complex<T>`.
    template <typename T>
    struct DataType<sycl::vec<T, 2>>
    {
        using T2 = std::complex<T>;
    };
 189
 190    static void destroy_event(event_ptr event)
 191    {
 192        delete event;
 193    }
 194
 195    static inline unsigned int get_tid()
 196    {
 197#if defined(__linux__)
 198        return syscall(SYS_gettid);
 199#elif defined(_WIN64)
 200        return GetCurrentThreadId();
 201#else
 202#error "Only support Windows and Linux."
 203#endif
 204    }
 205
 206    namespace detail
 207    {
 208        static void get_version(const sycl::device &dev, int &major, int &minor)
 209        {
 210            // Version string has the following format:
 211            // a. OpenCL<space><major.minor><space><vendor-specific-information>
 212            // b. <major.minor>
 213            // c. <AmdGcnArchName> e.g gfx1030
 214            std::string ver;
 215            ver = dev.get_info<sycl::info::device::version>();
 216            std::string::size_type i = 0;
 217            while (i < ver.size()) {
 218              if (isdigit(ver[i]))
 219                break;
 220              i++;
 221            }
 222            major = std::stoi(&(ver[i]));
 223            while (i < ver.size()) {
 224              if (ver[i] == '.')
 225                break;
 226              i++;
 227            }
 228            if (i < ver.size()) {
 229              // a. and b.
 230              i++;
 231              minor = std::stoi(&(ver[i]));
 232            } else {
 233              // c.
 234              minor = 0;
 235            }
 236        }
 237
 238        template <typename tag, typename T>
 239        class generic_error_type
 240        {
 241        public:
 242            generic_error_type() = default;
 243            generic_error_type(T value) : value{value} {}
 244            operator T() const { return value; }
 245
 246        private:
 247            T value;
 248        };
 249
 250    } // namespace detail
 251
    // Copied from the DPCT header files
 253    /// dim3 is used to store 3 component dimensions.
 254    class dim3 {
 255        public:
 256        unsigned x, y, z;
 257
 258        constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
 259            : x(x), y(y), z(z) {}
 260
 261        dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
 262
 263        operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
 264    }; // namespace dim3
 265
 266    inline dim3 operator*(const dim3 &a, const dim3 &b) {
 267    return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
 268    }
    // Copied from the DPCT header files
 270
 271
 272    /// Pitched 2D/3D memory data.
 273    class pitched_data
 274    {
 275    public:
 276        pitched_data() : pitched_data(nullptr, 0, 0, 0) {}
 277        pitched_data(void *data, size_t pitch, size_t x, size_t y)
 278            : _data(data), _pitch(pitch), _x(x), _y(y) {}
 279
 280        void *get_data_ptr() { return _data; }
 281        void set_data_ptr(void *data) { _data = data; }
 282
 283        size_t get_pitch() { return _pitch; }
 284        void set_pitch(size_t pitch) { _pitch = pitch; }
 285
 286        size_t get_x() { return _x; }
 287        void set_x(size_t x) { _x = x; }
 288
 289        size_t get_y() { return _y; }
 290        void set_y(size_t y) { _y = y; }
 291
 292    private:
 293        void *_data;
 294        size_t _pitch, _x, _y;
 295    };
 296
 297    class device_info
 298    {
 299    public:
 300        // get interface
 301        const char *get_name() const { return _name; }
 302        char *get_name() { return _name; }
 303        template <typename WorkItemSizesTy = sycl::range<3>,
 304                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
 305                                       std::is_same_v<WorkItemSizesTy, int *>,
 306                                   int> = 0>
 307        auto get_max_work_item_sizes() const
 308        {
 309            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
 310                return sycl::range<3>(_max_work_item_sizes_i[0],
 311                                      _max_work_item_sizes_i[1],
 312                                      _max_work_item_sizes_i[2]);
 313            else
 314            {
 315                return _max_work_item_sizes_i;
 316            }
 317        }
 318        template <typename WorkItemSizesTy = sycl::range<3>,
 319                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
 320                                       std::is_same_v<WorkItemSizesTy, int *>,
 321                                   int> = 0>
 322        auto get_max_work_item_sizes()
 323        {
 324            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
 325                return sycl::range<3>(_max_work_item_sizes_i[0],
 326                                      _max_work_item_sizes_i[1],
 327                                      _max_work_item_sizes_i[2]);
 328            else
 329            {
 330                return _max_work_item_sizes_i;
 331            }
 332        }
 333        bool get_host_unified_memory() const { return _host_unified_memory; }
 334        int get_major_version() const { return _major; }
 335        int get_minor_version() const { return _minor; }
 336        int get_integrated() const { return _integrated; }
 337        int get_max_clock_frequency() const { return _frequency; }
 338        int get_max_compute_units() const { return _max_compute_units; }
 339        int get_max_work_group_size() const { return _max_work_group_size; }
 340        int get_max_sub_group_size() const { return _max_sub_group_size; }
 341        int get_max_work_items_per_compute_unit() const
 342        {
 343            return _max_work_items_per_compute_unit;
 344        }
 345        int get_max_register_size_per_work_group() const
 346        {
 347            return _max_register_size_per_work_group;
 348        }
 349        template <typename NDRangeSizeTy = size_t *,
 350                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
 351                                       std::is_same_v<NDRangeSizeTy, int *>,
 352                                   int> = 0>
 353        auto get_max_nd_range_size() const
 354        {
 355            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
 356                return _max_nd_range_size;
 357            else
 358                return _max_nd_range_size_i;
 359        }
 360        template <typename NDRangeSizeTy = size_t *,
 361                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
 362                                       std::is_same_v<NDRangeSizeTy, int *>,
 363                                   int> = 0>
 364        auto get_max_nd_range_size()
 365        {
 366            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
 367                return _max_nd_range_size;
 368            else
 369                return _max_nd_range_size_i;
 370        }
 371        size_t get_global_mem_size() const { return _global_mem_size; }
 372        size_t get_local_mem_size() const { return _local_mem_size; }
 373        size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
 374        /// Returns the maximum clock rate of device's global memory in kHz. If
 375        /// compiler does not support this API then returns default value 3200000 kHz.
 376        unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
 377        /// Returns the maximum bus width between device and memory in bits. If
 378        /// compiler does not support this API then returns default value 64 bits.
 379        unsigned int get_memory_bus_width() const { return _memory_bus_width; }
 380        uint32_t get_device_id() const { return _device_id; }
 381        std::array<unsigned char, 16> get_uuid() const { return _uuid; }
 382        /// Returns global memory cache size in bytes.
 383        unsigned int get_global_mem_cache_size() const
 384        {
 385            return _global_mem_cache_size;
 386        }
 387
 388        // set interface
 389        void set_name(const char *name)
 390        {
 391            size_t length = strlen(name);
 392            if (length < 256)
 393            {
 394                std::memcpy(_name, name, length + 1);
 395            }
 396            else
 397            {
 398                std::memcpy(_name, name, 255);
 399                _name[255] = '\0';
 400            }
 401        }
 402        void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
 403        {
 404            for (int i = 0; i < 3; ++i)
 405                _max_work_item_sizes_i[i] = max_work_item_sizes[i];
 406        }
 407        [[deprecated]] void
 408        set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
 409        {
 410            for (int i = 0; i < 3; ++i)
 411            {
 412                _max_work_item_sizes_i[i] = max_work_item_sizes[i];
 413            }
 414        }
 415        void set_host_unified_memory(bool host_unified_memory)
 416        {
 417            _host_unified_memory = host_unified_memory;
 418        }
 419        void set_major_version(int major) { _major = major; }
 420        void set_minor_version(int minor) { _minor = minor; }
 421        void set_integrated(int integrated) { _integrated = integrated; }
 422        void set_max_clock_frequency(int frequency) { _frequency = frequency; }
 423        void set_max_compute_units(int max_compute_units)
 424        {
 425            _max_compute_units = max_compute_units;
 426        }
 427        void set_global_mem_size(size_t global_mem_size)
 428        {
 429            _global_mem_size = global_mem_size;
 430        }
 431        void set_local_mem_size(size_t local_mem_size)
 432        {
 433            _local_mem_size = local_mem_size;
 434        }
 435        void set_max_mem_alloc_size(size_t max_mem_alloc_size)
 436        {
 437            _max_mem_alloc_size = max_mem_alloc_size;
 438        }
 439        void set_max_work_group_size(int max_work_group_size)
 440        {
 441            _max_work_group_size = max_work_group_size;
 442        }
 443        void set_max_sub_group_size(int max_sub_group_size)
 444        {
 445            _max_sub_group_size = max_sub_group_size;
 446        }
 447        void
 448        set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
 449        {
 450            _max_work_items_per_compute_unit = max_work_items_per_compute_unit;
 451        }
 452        void set_max_nd_range_size(int max_nd_range_size[])
 453        {
 454            for (int i = 0; i < 3; i++)
 455            {
 456                _max_nd_range_size[i] = max_nd_range_size[i];
 457                _max_nd_range_size_i[i] = max_nd_range_size[i];
 458            }
 459        }
 460        void set_memory_clock_rate(unsigned int memory_clock_rate)
 461        {
 462            _memory_clock_rate = memory_clock_rate;
 463        }
 464        void set_memory_bus_width(unsigned int memory_bus_width)
 465        {
 466            _memory_bus_width = memory_bus_width;
 467        }
 468        void
 469        set_max_register_size_per_work_group(int max_register_size_per_work_group)
 470        {
 471            _max_register_size_per_work_group = max_register_size_per_work_group;
 472        }
 473        void set_device_id(uint32_t device_id)
 474        {
 475            _device_id = device_id;
 476        }
 477        void set_uuid(std::array<unsigned char, 16> uuid)
 478        {
 479            _uuid = std::move(uuid);
 480        }
 481        void set_global_mem_cache_size(unsigned int global_mem_cache_size)
 482        {
 483            _global_mem_cache_size = global_mem_cache_size;
 484        }
 485
 486    private:
 487        char _name[256];
 488        int _max_work_item_sizes_i[3];
 489        bool _host_unified_memory = false;
 490        int _major;
 491        int _minor;
 492        int _integrated = 0;
 493        int _frequency;
 494        // Set estimated value 3200000 kHz as default value.
 495        unsigned int _memory_clock_rate = 3200000;
 496        // Set estimated value 64 bits as default value.
 497        unsigned int _memory_bus_width = 64;
 498        unsigned int _global_mem_cache_size;
 499        int _max_compute_units;
 500        int _max_work_group_size;
 501        int _max_sub_group_size;
 502        int _max_work_items_per_compute_unit;
 503        int _max_register_size_per_work_group;
 504        size_t _global_mem_size;
 505        size_t _local_mem_size;
 506        size_t _max_mem_alloc_size;
 507        size_t _max_nd_range_size[3];
 508        int _max_nd_range_size_i[3];
 509        uint32_t _device_id;
 510        std::array<unsigned char, 16> _uuid;
 511    };
 512
 513    static int get_major_version(const sycl::device &dev)
 514    {
 515        int major, minor;
 516        detail::get_version(dev, major, minor);
 517        return major;
 518    }
 519
 520    static int get_minor_version(const sycl::device &dev)
 521    {
 522        int major, minor;
 523        detail::get_version(dev, major, minor);
 524        return minor;
 525    }
 526
    /// Queries \p dev and stores the collected properties in \p out.
    ///
    /// The values are gathered in a local device_info and assigned to \p out
    /// only at the end, so \p out is replaced as a whole. Several fields are
    /// not queried but filled with fixed estimates (see the comments below).
    ///
    /// \param [out] out Receives the collected properties.
    /// \param [in]  dev Device to query.
    static void get_device_info(device_info &out, const sycl::device &dev)
    {
        device_info prop;
        prop.set_name(dev.get_info<sycl::info::device::name>().c_str());

        int major, minor;
        detail::get_version(dev, major, minor);
        prop.set_major_version(major);
        prop.set_minor_version(minor);

        prop.set_max_work_item_sizes(
#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
            // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes
            // is an enum class element
            dev.get_info<sycl::info::device::max_work_item_sizes>());
#else
            // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by
            // an int
            dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
#endif
        prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));

        // The queried value is scaled by 1000 before it is stored; the
        // device_ext getter documents the stored frequency as kHz.
        prop.set_max_clock_frequency(
            dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);

        prop.set_max_compute_units(
            dev.get_info<sycl::info::device::max_compute_units>());
        prop.set_max_work_group_size(
            dev.get_info<sycl::info::device::max_work_group_size>());
        prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
        prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
        prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());

        // Intel-specific properties are only queried when the extension is
        // available at compile time and the device reports the matching
        // aspect; otherwise the device_info defaults stay in place.
#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
        if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
        {
            unsigned int tmp =
                dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
            // A reported rate of 0 keeps the default estimate.
            if (tmp != 0)
                prop.set_memory_clock_rate(1000 * tmp);
        }
        if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
        {
            prop.set_memory_bus_width(
                dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_id))
        {
            prop.set_device_id(
                dev.get_info<sycl::ext::intel::info::device::device_id>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
        {
            prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
        }
#elif defined(_MSC_VER) && !defined(__clang__)
#pragma message("get_device_info: querying memory_clock_rate and \
        memory_bus_width are not supported by the compiler used. \
        Use 3200000 kHz as memory_clock_rate default value. \
        Use 64 bits as memory_bus_width default value.")
#else
#warning "get_device_info: querying memory_clock_rate and \
        memory_bus_width are not supported by the compiler used. \
        Use 3200000 kHz as memory_clock_rate default value. \
        Use 64 bits as memory_bus_width default value."
#endif

        // The largest reported sub-group size wins; 1 is used when the
        // device reports none.
        size_t max_sub_group_size = 1;
        std::vector<size_t> sub_group_sizes =
            dev.get_info<sycl::info::device::sub_group_sizes>();

        for (const auto &sub_group_size : sub_group_sizes)
        {
            if (max_sub_group_size < sub_group_size)
                max_sub_group_size = sub_group_size;
        }

        prop.set_max_sub_group_size(max_sub_group_size);

        // No dedicated query is used here: the per-compute-unit work-item
        // limit is filled with the maximum work-group size.
        prop.set_max_work_items_per_compute_unit(
            dev.get_info<sycl::info::device::max_work_group_size>());
        // Fixed value 0x7FFFFFFF for every dimension, not a device query.
        int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
        prop.set_max_nd_range_size(max_nd_range_size);

        // Estimates max register size per work group, feel free to update the value
        // according to device properties.
        prop.set_max_register_size_per_work_group(65536);

        prop.set_global_mem_cache_size(
            dev.get_info<sycl::info::device::global_mem_cache_size>());
        out = prop;
    }
 619
 620    /// dpct device extension
 621    class device_ext : public sycl::device {
 622      typedef std::mutex mutex_type;
 623
 624     public:
 625      device_ext() : sycl::device() {}
 626      ~device_ext() {
 627        std::lock_guard<mutex_type> lock(m_mutex);
 628        clear_queues();
 629      }
 630      device_ext(const sycl::device &base) : sycl::device(base) {
 631        std::lock_guard<mutex_type> lock(m_mutex);
 632        init_queues();
 633      }
 634
 635      int is_native_atomic_supported() { return 0; }
 636      int get_major_version() const { return dpct::get_major_version(*this); }
 637
 638      int get_minor_version() const { return dpct::get_minor_version(*this); }
 639
 640      int get_max_compute_units() const {
 641        return get_device_info().get_max_compute_units();
 642      }
 643
 644      /// Return the maximum clock frequency of this device in KHz.
 645      int get_max_clock_frequency() const {
 646        return get_device_info().get_max_clock_frequency();
 647      }
 648
 649      int get_integrated() const { return get_device_info().get_integrated(); }
 650
 651      int get_max_sub_group_size() const {
 652        return get_device_info().get_max_sub_group_size();
 653      }
 654
 655      int get_max_register_size_per_work_group() const {
 656        return get_device_info().get_max_register_size_per_work_group();
 657      }
 658
 659      int get_max_work_group_size() const {
 660        return get_device_info().get_max_work_group_size();
 661      }
 662
 663      int get_mem_base_addr_align() const {
 664        return get_info<sycl::info::device::mem_base_addr_align>();
 665      }
 666
 667      size_t get_global_mem_size() const {
 668        return get_device_info().get_global_mem_size();
 669      }
 670
 671      size_t get_max_mem_alloc_size() const {
 672        return get_device_info().get_max_mem_alloc_size();
 673      }
 674
 675      /// Get the number of bytes of free and total memory on the SYCL device.
 676      /// \param [out] free_memory The number of bytes of free memory on the
 677      /// SYCL device. \param [out] total_memory The number of bytes of total
 678      /// memory on the SYCL device.
 679      void get_memory_info(size_t &free_memory, size_t &total_memory) {
 680        total_memory = get_device_info().get_global_mem_size();
 681        const char *warning_info =
 682            "get_memory_info: [warning] ext_intel_free_memory is not "
 683            "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
 684            "use total memory as free memory";
 685#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
 686        if (!has(sycl::aspect::ext_intel_free_memory)) {
 687          std::cerr << warning_info << std::endl;
 688          free_memory = total_memory;
 689        } else {
 690          free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
 691        }
 692#else
 693        std::cerr << warning_info << std::endl;
 694        free_memory = total_memory;
 695#if defined(_MSC_VER) && !defined(__clang__)
 696#pragma message("Querying the number of bytes of free memory is not supported")
 697#else
 698#warning "Querying the number of bytes of free memory is not supported"
 699#endif
 700#endif
 701      }
 702
 703      void get_device_info(device_info &out) const {
 704        dpct::get_device_info(out, *this);
 705      }
 706
 707      device_info get_device_info() const {
 708        device_info prop;
 709        dpct::get_device_info(prop, *this);
 710        return prop;
 711      }
 712
 713      void reset() {
 714        std::lock_guard<mutex_type> lock(m_mutex);
 715        clear_queues();
 716        init_queues();
 717      }
 718
 719      sycl::queue &in_order_queue() { return _q_in_order; }
 720
 721      sycl::queue &out_of_order_queue() { return _q_out_of_order; }
 722
 723      sycl::queue &default_queue() { return in_order_queue(); }
 724
 725      void queues_wait_and_throw() {
 726        std::unique_lock<mutex_type> lock(m_mutex);
 727        lock.unlock();
 728        for (auto &q : _queues) {
 729            q.wait_and_throw();
 730        }
 731        // Guard the destruct of current_queues to make sure the ref count is
 732        // safe.
 733        lock.lock();
 734      }
 735
 736      sycl::queue create_queue(bool enable_exception_handler = false) {
 737        return create_in_order_queue(enable_exception_handler);
 738      }
 739
 740      sycl::queue create_queue(sycl::device device,
 741                               bool enable_exception_handler = false) {
 742        return create_in_order_queue(device, enable_exception_handler);
 743      }
 744
 745      sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
 746        std::lock_guard<mutex_type> lock(m_mutex);
 747        return create_queue_impl(enable_exception_handler,
 748                                 sycl::property::queue::in_order());
 749      }
 750
 751      sycl::queue create_in_order_queue(sycl::device device,
 752                                        bool enable_exception_handler = false) {
 753        std::lock_guard<mutex_type> lock(m_mutex);
 754        return create_queue_impl(device, enable_exception_handler,
 755                                 sycl::property::queue::in_order());
 756      }
 757
 758      sycl::queue create_out_of_order_queue(
 759          bool enable_exception_handler = false) {
 760        std::lock_guard<mutex_type> lock(m_mutex);
 761        return create_queue_impl(enable_exception_handler);
 762      }
 763
 764      void destroy_queue(sycl::queue queue) {
 765        std::lock_guard<mutex_type> lock(m_mutex);
 766        _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
 767                                    [=](const sycl::queue &q) -> bool
 768                                    {
 769                                        return q == queue;
 770                                    }),
 771                    _queues.end());
 772      }
 773      void set_saved_queue(sycl::queue q) {
 774        std::lock_guard<mutex_type> lock(m_mutex);
 775        _saved_queue = q;
 776      }
 777      sycl::queue get_saved_queue() const {
 778        std::lock_guard<mutex_type> lock(m_mutex);
 779        return _saved_queue;
 780      }
 781
 782     private:
 783      void clear_queues() { _queues.clear(); }
 784
 785      void init_queues() {
 786        _q_in_order =
 787            create_queue_impl(true, sycl::property::queue::in_order());
 788        _q_out_of_order = create_queue_impl(true);
 789        _saved_queue = default_queue();
 790      }
 791
 792      /// Caller should acquire resource \p m_mutex before calling this
 793      /// function.
 794      template <class... Properties>
 795      sycl::queue create_queue_impl(bool enable_exception_handler,
 796                                    Properties... properties) {
 797        sycl::async_handler eh = {};
 798        if (enable_exception_handler) {
 799          eh = exception_handler;
 800        }
 801        _queues.push_back(sycl::queue(
 802            *this, eh,
 803            sycl::property_list(
 804#ifdef DPCT_PROFILING_ENABLED
 805                sycl::property::queue::enable_profiling(),
 806#endif
 807                properties...)));
 808
 809        return _queues.back();
 810      }
 811
 812      template <class... Properties>
 813      sycl::queue create_queue_impl(sycl::device device,
 814                                    bool enable_exception_handler,
 815                                    Properties... properties) {
 816        sycl::async_handler eh = {};
 817        if (enable_exception_handler) {
 818          eh = exception_handler;
 819        }
 820        _queues.push_back(sycl::queue(
 821            device, eh,
 822                        sycl::property_list(
 823#ifdef DPCT_PROFILING_ENABLED
 824                            sycl::property::queue::enable_profiling(),
 825#endif
 826                            properties...)));
 827
 828        return _queues.back();
 829      }
 830
 831      void get_version(int &major, int &minor) const {
 832        detail::get_version(*this, major, minor);
 833      }
 834      sycl::queue _q_in_order, _q_out_of_order;
 835      sycl::queue _saved_queue;
 836      std::vector<sycl::queue> _queues;
 837      mutable mutex_type m_mutex;
 838    };
 839
 840
 841    /// device manager
 842    class dev_mgr
 843    {
 844    public:
 845        device_ext &current_device()
 846        {
 847            unsigned int dev_id = current_device_id();
 848            check_id(dev_id);
 849            return *_devs[dev_id];
 850        }
 851        device_ext &cpu_device() const
 852        {
 853            std::lock_guard<std::recursive_mutex> lock(m_mutex);
 854            if (_cpu_device == -1)
 855            {
 856                throw std::runtime_error("no valid cpu device");
 857            }
 858            else
 859            {
 860                return *_devs[_cpu_device];
 861            }
 862        }
 863        device_ext &get_device(unsigned int id) const
 864        {
 865            std::lock_guard<std::recursive_mutex> lock(m_mutex);
 866            check_id(id);
 867            return *_devs[id];
 868        }
 869        unsigned int current_device_id() const
 870        {
 871            std::lock_guard<std::recursive_mutex> lock(m_mutex);
 872            auto it = _thread2dev_map.find(get_tid());
 873            if (it != _thread2dev_map.end())
 874                return it->second;
 875            return DEFAULT_DEVICE_ID;
 876        }
 877
 878        /// Select device with a device ID.
 879        /// \param [in] id The id of the device which can
 880        /// be obtained through get_device_id(const sycl::device).
 881        void select_device(unsigned int id)
 882        {
 883            std::lock_guard<std::recursive_mutex> lock(m_mutex);
 884            check_id(id);
 885            _thread2dev_map[get_tid()] = id;
 886        }
 887        unsigned int device_count() { return _devs.size(); }
 888
 889        unsigned int get_device_id(const sycl::device &dev)
 890        {
 891            unsigned int id = 0;
 892            for (auto &dev_item : _devs)
 893            {
 894                if (*dev_item == dev)
 895                {
 896                    return id;
 897                }
 898                id++;
 899            }
 900            return -1;
 901        }
 902
 903        inline std::string get_preferred_gpu_platform_name() {
 904            std::string result;
 905
 906            std::string filter = "";
 907            char* env = getenv("ONEAPI_DEVICE_SELECTOR");
 908            if (env) {
 909                if (std::strstr(env, "level_zero")) {
 910                    filter = "level-zero";
 911                }
 912                else if (std::strstr(env, "opencl")) {
 913                    filter = "opencl";
 914                }
 915                else if (std::strstr(env, "cuda")) {
 916                    filter = "cuda";
 917                }
 918                else if (std::strstr(env, "hip")) {
 919                    filter = "hip";
 920                }
 921                else {
 922                    throw std::runtime_error("invalid device filter: " + std::string(env));
 923                }
 924            } else {
 925                auto default_device = sycl::device(sycl::default_selector_v);
 926                auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
 927
 928                if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
 929                    filter = "level-zero";
 930                }
 931                else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
 932                    filter = "cuda";
 933                }
 934                else if (std::strstr(default_platform_name.c_str(), "HIP")) {
 935                    filter = "hip";
 936                }
 937            }
 938
 939            auto platform_list = sycl::platform::get_platforms();
 940
 941            for (const auto& platform : platform_list) {
 942                auto devices = platform.get_devices();
 943                auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
 944                    return d.is_gpu();
 945                });
 946
 947                if (gpu_dev == devices.end()) {
 948                    // cout << "platform [" << platform_name
 949                    //      << "] does not contain GPU devices, skipping\n";
 950                    continue;
 951                }
 952
 953                auto platform_name = platform.get_info<sycl::info::platform::name>();
 954                std::string platform_name_low_case;
 955                platform_name_low_case.resize(platform_name.size());
 956
 957                std::transform(
 958                    platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
 959
 960                if (platform_name_low_case.find(filter) == std::string::npos) {
 961                    // cout << "platform [" << platform_name
 962                    //      << "] does not match with requested "
 963                    //      << filter << ", skipping\n";
 964                    continue;
 965                }
 966
 967                result = platform_name;
 968            }
 969
 970            if (result.empty())
 971                throw std::runtime_error("can not find preferred GPU platform");
 972
 973            return result;
 974        }
 975
 976        template <class DeviceSelector>
 977        std::enable_if_t<
 978            std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
 979        select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
 980        {
 981            sycl::device selected_device = sycl::device(selector);
 982            unsigned int selected_device_id = get_device_id(selected_device);
 983            select_device(selected_device_id);
 984        }
 985
 986        /// Returns the instance of device manager singleton.
 987        static dev_mgr &instance()
 988        {
 989            static dev_mgr d_m;
 990            return d_m;
 991        }
 992        dev_mgr(const dev_mgr &) = delete;
 993        dev_mgr &operator=(const dev_mgr &) = delete;
 994        dev_mgr(dev_mgr &&) = delete;
 995        dev_mgr &operator=(dev_mgr &&) = delete;
 996
 997    private:
 998        mutable std::recursive_mutex m_mutex;
 999        static bool compare_dev(sycl::device &device1, sycl::device &device2)
1000        {
1001            sycl::backend backend1 = device1.get_backend();
1002            sycl::backend backend2 = device2.get_backend();
1003            // levelzero backends always come first
1004            if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
1005            if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
1006            dpct::device_info prop1;
1007            dpct::get_device_info(prop1, device1);
1008            dpct::device_info prop2;
1009            dpct::get_device_info(prop2, device2);
1010            return prop1.get_max_compute_units() > prop2.get_max_compute_units();
1011        }
1012        static int convert_backend_index(std::string & backend) {
1013            if (backend == "ext_oneapi_level_zero:gpu") return 0;
1014            if (backend == "opencl:gpu") return 1;
1015            if (backend == "ext_oneapi_cuda:gpu") return 2;
1016            if (backend == "ext_oneapi_hip:gpu") return 3;
1017            if (backend == "opencl:cpu") return 4;
1018            if (backend == "opencl:acc") return 5;
1019            printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
1020            GGML_ABORT("fatal error");
1021        }
1022        static bool compare_backend(std::string &backend1, std::string &backend2) {
1023            return convert_backend_index(backend1) < convert_backend_index(backend2);
1024        }
1025        dev_mgr()
1026        {
1027            sycl::device default_device =
1028                sycl::device(sycl::default_selector_v);
1029            _devs.push_back(std::make_shared<device_ext>(default_device));
1030
1031            std::vector<sycl::device> sycl_all_devs;
1032            // Collect other devices except for the default device.
1033            if (default_device.is_cpu())
1034                _cpu_device = 0;
1035
1036            auto Platforms = sycl::platform::get_platforms();
1037            // Keep track of the number of devices per backend
1038            std::map<sycl::backend, size_t> DeviceNums;
1039            std::map<std::string, std::vector<sycl::device>> backend_devices;
1040            auto preferred_platform_name = get_preferred_gpu_platform_name();
1041
1042            while (!Platforms.empty()) {
1043                auto Platform = Platforms.back();
1044                Platforms.pop_back();
1045                auto platform_name = Platform.get_info<sycl::info::platform::name>();
1046                if (platform_name.compare(preferred_platform_name) != 0) {
1047                    continue;
1048                }
1049                auto devices = Platform.get_devices();
1050                std::string backend_type = get_device_backend_and_type(devices[0]);
1051                for (const auto &device : devices) {
1052                    backend_devices[backend_type].push_back(device);
1053                }
1054            }
1055
1056            std::vector<std::string> keys;
1057            for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
1058                keys.push_back(it->first);
1059            }
1060            std::sort(keys.begin(), keys.end(), compare_backend);
1061
1062            for (auto &key : keys) {
1063                std::vector<sycl::device> devs = backend_devices[key];
1064                std::sort(devs.begin(), devs.end(), compare_dev);
1065                for (const auto &dev : devs) {
1066                    sycl_all_devs.push_back(dev);
1067                }
1068            }
1069
1070            for (auto &dev : sycl_all_devs)
1071            {
1072                if (dev == default_device)
1073                {
1074                    continue;
1075                }
1076                _devs.push_back(std::make_shared<device_ext>(dev));
1077                if (_cpu_device == -1 && dev.is_cpu())
1078                {
1079                    _cpu_device = _devs.size() - 1;
1080                }
1081            }
1082        }
1083        void check_id(unsigned int id) const
1084        {
1085            if (id >= _devs.size())
1086            {
1087                throw std::runtime_error("invalid device id");
1088            }
1089        }
1090        std::vector<std::shared_ptr<device_ext>> _devs;
1091        /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current
1092        /// thread id in _thread2dev_map, which means default device should be used
1093        /// for the current thread.
1094        const unsigned int DEFAULT_DEVICE_ID = 0;
1095        /// thread-id to device-id map.
1096        std::map<unsigned int, unsigned int> _thread2dev_map;
1097        int _cpu_device = -1;
1098    };
1099
1100    static inline sycl::queue &get_default_queue()
1101    {
1102        return dev_mgr::instance().current_device().default_queue();
1103    }
1104
1105    namespace detail
1106    {
        /// Where a pointer may be dereferenced. The numeric values are used as
        /// row/column indices of the direction table in
        /// deduce_memcpy_direction().
        enum class pointer_access_attribute
        {
            host_only = 0,
            device_only,
            host_device,
            // Number of real attributes; used as the table dimension.
            end
        };
1114
1115        static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
1116                                                              const void *ptr)
1117        {
1118            switch (sycl::get_pointer_type(ptr, q.get_context()))
1119            {
1120            case sycl::usm::alloc::unknown:
1121                return pointer_access_attribute::host_only;
1122            case sycl::usm::alloc::device:
1123                return pointer_access_attribute::device_only;
1124            case sycl::usm::alloc::shared:
1125            case sycl::usm::alloc::host:
1126                return pointer_access_attribute::host_device;
1127            }
1128        }
1129
1130        template <typename ArgT>
1131        inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
1132        {
1133            static_assert((unsigned char)library_data_t::library_data_t_size <=
1134                              std::numeric_limits<unsigned char>::max() &&
1135                          "library_data_t size exceeds limit.");
1136            static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
1137            return (std::uint64_t)Val;
1138        }
1139
1140        template <typename FirstT, typename... RestT>
1141        inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
1142                                                               RestT... RestVal)
1143        {
1144            static_assert((std::uint8_t)library_data_t::library_data_t_size <=
1145                              std::numeric_limits<unsigned char>::max() &&
1146                          "library_data_t size exceeds limit.");
1147            static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
1148            static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
1149            return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
1150        }
1151
1152        class mem_mgr
1153        {
1154            mem_mgr()
1155            {
1156                // Reserved address space, no real memory allocation happens here.
1157#if defined(__linux__)
1158                mapped_address_space =
1159                    (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
1160                                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
1161#elif defined(_WIN64)
1162                mapped_address_space = (byte_t *)VirtualAlloc(
1163                    NULL,               // NULL specified as the base address parameter
1164                    mapped_region_size, // Size of allocation
1165                    MEM_RESERVE,        // Allocate reserved pages
1166                    PAGE_NOACCESS);     // Protection = no access
1167#else
1168#error "Only support Windows and Linux."
1169#endif
1170                next_free = mapped_address_space;
1171            }
1172
1173        public:
1174            using buffer_id_t = int;
1175
1176            struct allocation
1177            {
1178                buffer_t buffer;
1179                byte_t *alloc_ptr;
1180                size_t size;
1181            };
1182
1183            ~mem_mgr()
1184            {
1185#if defined(__linux__)
1186                munmap(mapped_address_space, mapped_region_size);
1187#elif defined(_WIN64)
1188                VirtualFree(mapped_address_space, 0, MEM_RELEASE);
1189#else
1190#error "Only support Windows and Linux."
1191#endif
1192            }
1193
1194            mem_mgr(const mem_mgr &) = delete;
1195            mem_mgr &operator=(const mem_mgr &) = delete;
1196            mem_mgr(mem_mgr &&) = delete;
1197            mem_mgr &operator=(mem_mgr &&) = delete;
1198
1199            /// Allocate
1200            void *mem_alloc(size_t size)
1201            {
1202                if (!size)
1203                    return nullptr;
1204                std::lock_guard<std::mutex> lock(m_mutex);
1205                if (next_free + size > mapped_address_space + mapped_region_size)
1206                {
1207                    throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
1208                }
1209                // Allocation
1210                sycl::range<1> r(size);
1211                buffer_t buf(r);
1212                allocation A{buf, next_free, size};
1213                // Map allocation to device pointer
1214                void *result = next_free;
1215                m_map.emplace(next_free + size, A);
1216                // Update pointer to the next free space.
1217                next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);
1218
1219                return result;
1220            }
1221
1222            /// Deallocate
1223            void mem_free(const void *ptr)
1224            {
1225                if (!ptr)
1226                    return;
1227                std::lock_guard<std::mutex> lock(m_mutex);
1228                auto it = get_map_iterator(ptr);
1229                m_map.erase(it);
1230            }
1231
1232            /// map: device pointer -> allocation(buffer, alloc_ptr, size)
1233            allocation translate_ptr(const void *ptr)
1234            {
1235                std::lock_guard<std::mutex> lock(m_mutex);
1236                auto it = get_map_iterator(ptr);
1237                return it->second;
1238            }
1239
1240            /// Check if the pointer represents device pointer or not.
1241            bool is_device_ptr(const void *ptr) const
1242            {
1243                std::lock_guard<std::mutex> lock(m_mutex);
1244                return (mapped_address_space <= ptr) &&
1245                       (ptr < mapped_address_space + mapped_region_size);
1246            }
1247
1248            /// Returns the instance of memory manager singleton.
1249            static mem_mgr &instance()
1250            {
1251                static mem_mgr m;
1252                return m;
1253            }
1254
1255        private:
1256            std::map<byte_t *, allocation> m_map;
1257            mutable std::mutex m_mutex;
1258            byte_t *mapped_address_space;
1259            byte_t *next_free;
1260            const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
1261            const size_t alignment = 256;
1262            /// This padding may be defined to some positive value to debug
1263            /// out of bound accesses.
1264            const size_t extra_padding = 0;
1265
1266            std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
1267            {
1268                auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
1269                if (it == m_map.end())
1270                {
1271                    // Not a virtual pointer.
1272                    throw std::runtime_error("can not get buffer from non-virtual pointer");
1273                }
1274                const allocation &alloc = it->second;
1275                if (ptr < alloc.alloc_ptr)
1276                {
1277                    // Out of bound.
1278                    // This may happen if there's a gap between allocations due to alignment
1279                    // or extra padding and pointer points to this gap.
1280                    throw std::runtime_error("invalid virtual pointer");
1281                }
1282                return it;
1283            }
1284        };
1285
1286        template <class T, memory_region Memory, size_t Dimension>
1287        class accessor;
1288        template <memory_region Memory, class T = byte_t>
1289        class memory_traits
1290        {
1291        public:
1292            static constexpr sycl::access::target target =
1293                sycl::access::target::device;
1294            static constexpr sycl::access_mode mode =
1295                (Memory == constant) ? sycl::access_mode::read
1296                                     : sycl::access_mode::read_write;
1297            static constexpr size_t type_size = sizeof(T);
1298            using element_t =
1299                typename std::conditional<Memory == constant, const T, T>::type;
1300            using value_t = typename std::remove_cv<T>::type;
1301            template <size_t Dimension = 1>
1302            using accessor_t = typename std::conditional<
1303                Memory == local, sycl::local_accessor<value_t, Dimension>,
1304                sycl::accessor<T, Dimension, mode, target>>::type;
1305            using pointer_t = T *;
1306        };
1307
1308        static inline void *dpct_malloc(size_t size, sycl::queue &q)
1309        {
1310            return sycl::malloc_device(size, q.get_device(), q.get_context());
1311        }
1312
1313#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
1314        static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
1315                                        sycl::queue &q)
1316        {
1317            pitch = PITCH_DEFAULT_ALIGN(x);
1318            return dpct_malloc(pitch * y * z, q);
1319        }
1320
1321        /**
1322         * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
1323         * @tparam valueT The type of the element to be set.
1324         * @param [in] q The queue in which the operation is done.
1325         * @param [in] dev_ptr Pointer to the virtual device memory address.
1326         * @param [in] value The value to be set.
1327         * @param [in] size Number of elements to be set to the value.
1328         * @return An event representing the memset operation.
1329         */
1330        template <typename valueT>
1331        static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
1332                                              valueT value, size_t size)
1333        {
1334            return q.fill(dev_ptr, value, size);
1335        }
1336
1337        /**
1338         * @brief Sets \p value to the 3D memory region pointed by \p data in \p q.
1339         * @tparam valueT The type of the element to be set.
1340         * @param [in] q The queue in which the operation is done.
1341         * @param [in] data Pointer to the pitched device memory region.
1342         * @param [in] value The value to be set.
1343         * @param [in] size 3D memory region by number of elements.
1344         * @return An event list representing the memset operations.
1345         */
1346        template <typename valueT>
1347        static inline std::vector<sycl::event>
1348        dpct_memset(sycl::queue &q, pitched_data data, valueT value,
1349                    sycl::range<3> size)
1350        {
1351            std::vector<sycl::event> event_list;
1352            size_t slice = data.get_pitch() * data.get_y();
1353            unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
1354            for (size_t z = 0; z < size.get(2); ++z)
1355            {
1356                unsigned char *data_ptr = data_surface;
1357                for (size_t y = 0; y < size.get(1); ++y)
1358                {
1359                    event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
1360                    data_ptr += data.get_pitch();
1361                }
1362                data_surface += slice;
1363            }
1364            return event_list;
1365        }
1366
1367        /**
1368         * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q.
1369         * @tparam valueT The type of the element to be set.
1370         * @param [in] q The queue in which the operation is done.
1371         * @param [in] ptr Pointer to the virtual device memory.
1372         * @param [in] pitch The pitch size by number of elements, including padding.
1373         * @param [in] val The value to be set.
1374         * @param [in] x The width of memory region by number of elements.
1375         * @param [in] y The height of memory region by number of elements.
1376         * @return An event list representing the memset operations.
1377         */
1378        template <typename valueT>
1379        static inline std::vector<sycl::event>
1380        dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
1381                    size_t y)
1382        {
1383            return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
1384                               sycl::range<3>(x, y, 1));
1385        }
1386
1387        static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
1388                                                        const void *from_ptr,
1389                                                        memcpy_direction dir)
1390        {
1391            switch (dir)
1392            {
1393            case memcpy_direction::host_to_host:
1394            case memcpy_direction::host_to_device:
1395            case memcpy_direction::device_to_host:
1396            case memcpy_direction::device_to_device:
1397                return dir;
1398            case memcpy_direction::automatic:
1399            {
1400                // table[to_attribute][from_attribute]
1401                static const memcpy_direction
1402                    direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
1403                                   [static_cast<unsigned>(pointer_access_attribute::end)] =
1404                                       {{memcpy_direction::host_to_host,
1405                                         memcpy_direction::device_to_host,
1406                                         memcpy_direction::host_to_host},
1407                                        {memcpy_direction::host_to_device,
1408                                         memcpy_direction::device_to_device,
1409                                         memcpy_direction::device_to_device},
1410                                        {memcpy_direction::host_to_host,
1411                                         memcpy_direction::device_to_device,
1412                                         memcpy_direction::device_to_device}};
1413                return direction_table[static_cast<unsigned>(get_pointer_attribute(
1414                    q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
1415            }
1416            default:
1417                throw std::runtime_error("dpct_memcpy: invalid direction value");
1418            }
1419        }
1420
1421        static sycl::event
1422        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
1423                    memcpy_direction direction,
1424                    const std::vector<sycl::event> &dep_events = {})
1425        {
1426            if (!size)
1427                return sycl::event{};
1428            return q.memcpy(to_ptr, from_ptr, size, dep_events);
1429            GGML_UNUSED(direction);
1430        }
1431
1432        // Get actual copy range and make sure it will not exceed range.
1433        static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
1434                                            size_t pitch)
1435        {
1436            return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
1437        }
1438
1439        static inline size_t get_offset(sycl::id<3> id, size_t slice,
1440                                        size_t pitch)
1441        {
1442            return slice * id.get(2) + pitch * id.get(1) + id.get(0);
1443        }
1444
        /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
        /// and \p from_range to another specified by \p to_ptr and \p to_range.
        ///
        /// All pointer arithmetic is done on unsigned char pointers, so ranges,
        /// ids and sizes are byte counts. For both ranges, get(0) is used as the
        /// row stride and get(0) * get(1) as the slice stride.
        /// \param [in] q Queue all operations are submitted to.
        /// \param [in] to_ptr Base pointer of the destination matrix.
        /// \param [in] from_ptr Base pointer of the source matrix.
        /// \param [in] to_range Layout of the destination matrix.
        /// \param [in] from_range Layout of the source matrix.
        /// \param [in] to_id Start position inside the destination matrix.
        /// \param [in] from_id Start position inside the source matrix.
        /// \param [in] size Extent of the region to copy.
        /// \param [in] direction Copy direction; deduced from the base pointers
        /// when it is `automatic`.
        /// \param [in] dep_events Events the copy has to wait for.
        /// \return The events of the submitted copy operations.
        /// \throws std::runtime_error if the direction is invalid.
        static inline std::vector<sycl::event>
        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
                    sycl::range<3> to_range, sycl::range<3> from_range,
                    sycl::id<3> to_id, sycl::id<3> from_id,
                    sycl::range<3> size, memcpy_direction direction,
                    const std::vector<sycl::event> &dep_events = {})
        {
            // RAII for host pointer: a malloc'ed staging buffer whose free is
            // deferred to a host task on the queue.
            // NOTE(review): the result of std::malloc is not checked before the
            // buffer is used as a copy source/target below.
            class host_buffer
            {
                void *_buf;
                size_t _size;
                sycl::queue &_q;
                const std::vector<sycl::event> &_deps; // free operation depends

            public:
                host_buffer(size_t size, sycl::queue &q,
                            const std::vector<sycl::event> &deps)
                    : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
                void *get_ptr() const { return _buf; }
                size_t get_size() const { return _size; }
                ~host_buffer()
                {
                    if (_buf)
                    {
                        // The buffer is freed by a host task that runs after
                        // the events in _deps. _deps is a reference, so it sees
                        // whatever the referenced vector holds at this point.
                        _q.submit([&](sycl::handler &cgh)
                                  {
        cgh.depends_on(_deps);
        cgh.host_task([buf = _buf] { std::free(buf); }); });
                    }
                }
            };
            std::vector<sycl::event> event_list;

            size_t to_slice = to_range.get(1) * to_range.get(0),
                   from_slice = from_range.get(1) * from_range.get(0);
            // Addresses of the first byte to write / read, i.e. the base
            // pointers advanced by the start positions.
            unsigned char *to_surface =
                (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
            const unsigned char *from_surface =
                (const unsigned char *)from_ptr +
                get_offset(from_id, from_slice, from_range.get(0));

            // Fast path: source, destination and copied region all have the
            // same slice size, so the whole region is copied with one memcpy.
            if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
            {
                return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
                                    direction, dep_events)};
            }
            direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
            size_t size_slice = size.get(1) * size.get(0);
            switch (direction)
            {
            case host_to_host:
                // One slice per iteration; a slice is copied in one piece when
                // the row widths match, otherwise row by row.
                for (size_t z = 0; z < size.get(2); ++z)
                {
                    // These locals shadow the function parameters on purpose:
                    // they walk the rows of the current slice.
                    unsigned char *to_ptr = to_surface;
                    const unsigned char *from_ptr = from_surface;
                    if (to_range.get(0) == from_range.get(0) &&
                        to_range.get(0) == size.get(0))
                    {
                        event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
                                                         direction, dep_events));
                    }
                    else
                    {
                        for (size_t y = 0; y < size.get(1); ++y)
                        {
                            event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
                                                             direction, dep_events));
                            to_ptr += to_range.get(0);
                            from_ptr += from_range.get(0);
                        }
                    }
                    to_surface += to_slice;
                    from_surface += from_slice;
                }
                break;
            case host_to_device:
            {
                // Staging buffer laid out like the destination; its free waits
                // for the events in event_list.
                host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
                                event_list);
                std::vector<sycl::event> host_events;
                if (to_slice == size_slice)
                {
                    // Copy host data to a temp host buffer with the shape of target.
                    host_events =
                        dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
                                    sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
                                    host_to_host, dep_events);
                }
                else
                {
                    // Copy host data to a temp host buffer with the shape of target.
                    host_events = dpct_memcpy(
                        q, buf.get_ptr(), from_surface, to_range, from_range,
                        sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
                        // If has padding data, not sure whether it is useless. So fill temp
                        // buffer with it.
                        std::vector<sycl::event>{
                            dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
                                        device_to_host, dep_events)});
                }
                // Copy from temp host buffer to device with only one submit.
                event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
                                                 buf.get_size(), host_to_device,
                                                 host_events));
                break;
            }
            case device_to_host:
            {
                // Staging buffer laid out like the source; its free waits for
                // the events in event_list, which is assigned just below.
                host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
                                event_list);
                // Copy from host temp buffer to host target with reshaping.
                event_list = dpct_memcpy(
                    q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
                    sycl::id<3>(0, 0, 0), size, host_to_host,
                    // Copy from device to temp host buffer with only one submit.
                    std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
                                                         buf.get_size(),
                                                         device_to_host, dep_events)});
                break;
            }
            case device_to_device:
                // A kernel over the copied region; each work-item copies one
                // byte from its source offset to its destination offset.
                event_list.push_back(q.submit([&](sycl::handler &cgh){
                cgh.depends_on(dep_events);
                cgh.parallel_for<class dpct_memcpy_3d_detail>(
                    size,
                    [=](sycl::id<3> id) {
                        to_surface[get_offset(id, to_slice, to_range.get(0))] =
                            from_surface[get_offset(id, from_slice, from_range.get(0))];
                    }); }));
                break;
            default:
                throw std::runtime_error("dpct_memcpy: invalid direction value");
            }
            return event_list;
        }
1583
1584        /// memcpy 2D/3D matrix specified by pitched_data.
1585        static inline std::vector<sycl::event>
1586        dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
1587                    pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
1588                    memcpy_direction direction = automatic)
1589        {
1590            return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
1591                               sycl::range<3>(to.get_pitch(), to.get_y(), 1),
1592                               sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
1593                               size, direction);
1594        }
1595
1596        /// memcpy 2D matrix with pitch.
1597        static inline std::vector<sycl::event>
1598        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
1599                    size_t to_pitch, size_t from_pitch, size_t x, size_t y,
1600                    memcpy_direction direction = automatic)
1601        {
1602            return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
1603                               sycl::range<3>(from_pitch, y, 1),
1604                               sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
1605                               sycl::range<3>(x, y, 1), direction);
1606        }
1607
1608        namespace deprecated
1609        {
1610
1611            template <typename T, sycl::usm::alloc AllocKind>
1612            class usm_allocator
1613            {
1614            private:
1615                using Alloc = sycl::usm_allocator<T, AllocKind>;
1616                Alloc _impl;
1617
1618            public:
1619                using value_type = typename std::allocator_traits<Alloc>::value_type;
1620                using pointer = typename std::allocator_traits<Alloc>::pointer;
1621                using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
1622                using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
1623                using const_void_pointer =
1624                    typename std::allocator_traits<Alloc>::const_void_pointer;
1625                using reference = typename std::allocator_traits<Alloc>::value_type &;
1626                using const_reference =
1627                    const typename std::allocator_traits<Alloc>::value_type &;
1628                using difference_type =
1629                    typename std::allocator_traits<Alloc>::difference_type;
1630                using size_type = typename std::allocator_traits<Alloc>::size_type;
1631                using propagate_on_container_copy_assignment = typename std::allocator_traits<
1632                    Alloc>::propagate_on_container_copy_assignment;
1633                using propagate_on_container_move_assignment = typename std::allocator_traits<
1634                    Alloc>::propagate_on_container_move_assignment;
1635                using propagate_on_container_swap =
1636                    typename std::allocator_traits<Alloc>::propagate_on_container_swap;
1637                using is_always_equal =
1638                    typename std::allocator_traits<Alloc>::is_always_equal;
1639
1640                template <typename U>
1641                struct rebind
1642                {
1643                    typedef usm_allocator<U, AllocKind> other;
1644                };
1645
1646                usm_allocator() : _impl(dpct::get_default_queue()) {}
1647                ~usm_allocator() {}
1648                usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
1649                usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
1650                pointer address(reference r) { return &r; }
1651                const_pointer address(const_reference r) { return &r; }
1652                pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
1653                {
1654                    return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
1655                }
1656                void deallocate(pointer p, size_type cnt)
1657                {
1658                    std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
1659                }
1660                size_type max_size() const
1661                {
1662                    return std::allocator_traits<Alloc>::max_size(_impl);
1663                }
1664                bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
1665                bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
1666            };
1667
1668        } // namespace deprecated
1669
1670        inline void dpct_free(void *ptr,
1671                              const sycl::queue &q)
1672        {
1673            if (ptr)
1674            {
1675                sycl::free(ptr, q.get_context());
1676            }
1677        }
1678
1679        template <typename T>
1680        inline auto get_memory(const void *x)
1681        {
1682            T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
1683            return new_x;
1684        }
1685
1686        template <typename T>
1687        inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
1688        {
1689            using Ty = typename DataType<T>::T2;
1690            Ty s_h;
1691            if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
1692                detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
1693                    .wait();
1694            else
1695                s_h = *reinterpret_cast<const Ty *>(s);
1696            return s_h;
1697        }
1698
1699    } // namespace detail
1700
1701    template <typename T>
1702    inline auto get_value(const T *s, sycl::queue &q)
1703    {
1704        return detail::get_value(s, q);
1705    }
1706
1707    namespace detail
1708    {
1709    template <class Ta, class Tb, class Tc, class Ts>
1710    inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
1711                          int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1712                          const void * beta, void * c, int ldc) {
1713        Ts   alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1714        Ts   beta_value  = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1715        auto data_a      = get_memory<const Ta>(a);
1716        auto data_b      = get_memory<const Tb>(b);
1717        auto data_c      = get_memory<Tc>(c);
1718        oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
1719                                               lda, data_b, ldb, beta_value, data_c, ldc);
1720    }
1721
1722        template <typename VecT, class BinaryOperation, class = void>
1723        class vectorized_binary
1724        {
1725        public:
1726            inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
1727            {
1728                VecT v4;
1729                for (size_t i = 0; i < v4.size(); ++i)
1730                {
1731                    v4[i] = binary_op(a[i], b[i]);
1732                }
1733                return v4;
1734            }
1735        };
1736
1737        template <typename VecT, class BinaryOperation>
1738        class vectorized_binary<
1739            VecT, BinaryOperation,
1740            std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
1741        {
1742        public:
1743            inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
1744            {
1745                return binary_op(a, b).template as<VecT>();
1746            }
1747        };
1748
1749        template <class Ta, class Tb, class Tc, class Ts>
1750        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1751                                    int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1752                                    int ldb, const void * beta, void ** c, int ldc, int batch_size,
1753                                    matrix_info_t<float> * matrix_info) {
1754            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1755            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1756
1757            matrix_info->transpose_info[0] = a_trans;
1758            matrix_info->transpose_info[1] = b_trans;
1759            matrix_info->value_info[0] = alpha_value;
1760            matrix_info->value_info[1] = beta_value;
1761            matrix_info->size_info[0] = m;
1762            matrix_info->size_info[1] = n;
1763            matrix_info->size_info[2] = k;
1764            matrix_info->ld_info[0] = lda;
1765            matrix_info->ld_info[1] = ldb;
1766            matrix_info->ld_info[2] = ldc;
1767            matrix_info->groupsize_info = batch_size;
1768
1769            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1770                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1771                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1772                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1773                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1774                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
1775                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1776        }
1777
1778        template <class Ta, class Tb, class Tc, class Ts>
1779        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1780                                    int m, int n, int k, const void * alpha, const void * a, int lda,
1781                                    long long int stride_a, const void * b, int ldb, long long int stride_b,
1782                                    const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
1783            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1784            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1785            auto data_a = get_memory<const Ta>(a);
1786            auto data_b = get_memory<const Tb>(b);
1787            auto data_c = get_memory<Tc>(c);
1788            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
1789                                                         data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1790                                                         data_c, ldc, stride_c, batch_size);
1791        }
1792
1793    } // namespace detail
1794
1795    template <typename VecT, class BinaryOperation>
1796    inline unsigned vectorized_binary(unsigned a, unsigned b,
1797                                      const BinaryOperation binary_op)
1798    {
1799        sycl::vec<unsigned, 1> v0{a}, v1{b};
1800        auto v2 = v0.as<VecT>();
1801        auto v3 = v1.as<VecT>();
1802        auto v4 =
1803            detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1804        v0 = v4.template as<sycl::vec<unsigned, 1>>();
1805        return v0;
1806    }
1807
1808    static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
1809                                  memcpy_direction direction = automatic,
1810                                  sycl::queue &q = dpct::get_default_queue())
1811    {
1812        detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
1813    }
1814
1815    static inline unsigned int select_device(unsigned int id)
1816    {
1817        dev_mgr::instance().select_device(id);
1818        return id;
1819    }
1820
1821    template <typename T>
1822    T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
1823                               unsigned int logical_sub_group_size = 32)
1824    {
1825        unsigned int id = g.get_local_linear_id();
1826        unsigned int start_index =
1827            id / logical_sub_group_size * logical_sub_group_size;
1828        unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
1829        return sycl::select_from_group(g, x,
1830                                       target_offset < logical_sub_group_size
1831                                           ? start_index + target_offset
1832                                           : id);
1833    }
1834
1835    template <typename T1, typename T2>
1836    using dot_product_acc_t = std::conditional_t<
1837        std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
1838        uint32_t,
1839        int32_t>;
1840
1841    template <typename T>
1842    sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
1843      return sycl::vec<T, 1>(val)
1844          .template as<sycl::vec<
1845              std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
1846              4>>()
1847          .template convert<T>();
1848    }
1849
1850    template <typename T1, typename T2, typename T3>
1851    inline auto dp4a(T1 a, T2 b, T3 c) {
1852      dot_product_acc_t<T1, T2> res = c;
1853      auto va = extract_and_sign_or_zero_extend4(a);
1854      auto vb = extract_and_sign_or_zero_extend4(b);
1855      res += va[0] * vb[0];
1856      res += va[1] * vb[1];
1857      res += va[2] * vb[2];
1858      res += va[3] * vb[3];
1859      return res;
1860    }
1861
1862    struct sub_sat
1863    {
1864        template <typename T>
1865        auto operator()(const T x, const T y) const
1866        {
1867            return sycl::sub_sat(x, y);
1868        }
1869    };
1870
1871    template <typename S, typename T>
1872    inline T vectorized_min(T a, T b)
1873    {
1874        sycl::vec<T, 1> v0{a}, v1{b};
1875        auto v2 = v0.template as<S>();
1876        auto v3 = v1.template as<S>();
1877        auto v4 = sycl::min(v2, v3);
1878        v0 = v4.template as<sycl::vec<T, 1>>();
1879        return v0;
1880    }
1881
1882    inline float pow(const float a, const int b) { return sycl::pown(a, b); }
1883    inline double pow(const double a, const int b) { return sycl::pown(a, b); }
1884    inline float pow(const float a, const float b) { return sycl::pow(a, b); }
1885    inline double pow(const double a, const double b) { return sycl::pow(a, b); }
1886    template <typename T, typename U>
1887    inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
1888    pow(const T a, const U b)
1889    {
1890        return sycl::pow(a, static_cast<T>(b));
1891    }
1892    template <typename T, typename U>
1893    inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
1894    pow(const T a, const U b)
1895    {
1896        return sycl::pow(static_cast<double>(a), static_cast<double>(b));
1897    }
1898
1899    inline double min(const double a, const float b)
1900    {
1901        return sycl::fmin(a, static_cast<double>(b));
1902    }
1903    inline double min(const float a, const double b)
1904    {
1905        return sycl::fmin(static_cast<double>(a), b);
1906    }
1907    inline float min(const float a, const float b) { return sycl::fmin(a, b); }
1908    inline double min(const double a, const double b) { return sycl::fmin(a, b); }
1909    inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
1910    {
1911        return sycl::min(a, static_cast<std::uint32_t>(b));
1912    }
1913    inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
1914    {
1915        return sycl::min(static_cast<std::uint32_t>(a), b);
1916    }
1917    inline std::int32_t min(const std::int32_t a, const std::int32_t b)
1918    {
1919        return sycl::min(a, b);
1920    }
1921    inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
1922    {
1923        return sycl::min(a, b);
1924    }
1925    inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
1926    {
1927        return sycl::min(a, static_cast<std::uint64_t>(b));
1928    }
1929    inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
1930    {
1931        return sycl::min(static_cast<std::uint64_t>(a), b);
1932    }
1933    inline std::int64_t min(const std::int64_t a, const std::int64_t b)
1934    {
1935        return sycl::min(a, b);
1936    }
1937    inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
1938    {
1939        return sycl::min(a, b);
1940    }
1941    inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
1942    {
1943        return sycl::min(a, static_cast<std::uint64_t>(b));
1944    }
1945    inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
1946    {
1947        return sycl::min(static_cast<std::uint64_t>(a), b);
1948    }
1949    inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
1950    {
1951        return sycl::min(a, static_cast<std::uint64_t>(b));
1952    }
1953    inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)
1954    {
1955        return sycl::min(static_cast<std::uint64_t>(a), b);
1956    }
1957    // max function overloads.
1958    // For floating-point types, `float` or `double` arguments are acceptable.
1959    // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
1960    // `std::int64_t` type arguments are acceptable.
1961    inline double max(const double a, const float b)
1962    {
1963        return sycl::fmax(a, static_cast<double>(b));
1964    }
1965    inline double max(const float a, const double b)
1966    {
1967        return sycl::fmax(static_cast<double>(a), b);
1968    }
1969    inline float max(const float a, const float b) { return sycl::fmax(a, b); }
1970    inline double max(const double a, const double b) { return sycl::fmax(a, b); }
1971    inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
1972    {
1973        return sycl::max(a, static_cast<std::uint32_t>(b));
1974    }
1975    inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
1976    {
1977        return sycl::max(static_cast<std::uint32_t>(a), b);
1978    }
1979    inline std::int32_t max(const std::int32_t a, const std::int32_t b)
1980    {
1981        return sycl::max(a, b);
1982    }
1983    inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
1984    {
1985        return sycl::max(a, b);
1986    }
1987    inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
1988    {
1989        return sycl::max(a, static_cast<std::uint64_t>(b));
1990    }
1991    inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
1992    {
1993        return sycl::max(static_cast<std::uint64_t>(a), b);
1994    }
1995    inline std::int64_t max(const std::int64_t a, const std::int64_t b)
1996    {
1997        return sycl::max(a, b);
1998    }
1999    inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
2000    {
2001        return sycl::max(a, b);
2002    }
2003    inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
2004    {
2005        return sycl::max(a, static_cast<std::uint64_t>(b));
2006    }
2007    inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
2008    {
2009        return sycl::max(static_cast<std::uint64_t>(a), b);
2010    }
2011    inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
2012    {
2013        return sycl::max(a, static_cast<std::uint64_t>(b));
2014    }
2015    inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
2016    {
2017        return sycl::max(static_cast<std::uint64_t>(a), b);
2018    }
2019
2020    inline void
2021    has_capability_or_fail(const sycl::device &dev,
2022                           const std::initializer_list<sycl::aspect> &props)
2023    {
2024        for (const auto &it : props)
2025        {
2026            if (dev.has(it))
2027                continue;
2028            switch (it)
2029            {
2030            case sycl::aspect::fp64:
2031                throw std::runtime_error("'double' is not supported in '" +
2032                                         dev.get_info<sycl::info::device::name>() +
2033                                         "' device");
2034                break;
2035            case sycl::aspect::fp16:
2036                throw std::runtime_error("'half' is not supported in '" +
2037                                         dev.get_info<sycl::info::device::name>() +
2038                                         "' device");
2039                break;
2040            default:
2041#define __SYCL_ASPECT(ASPECT, ID) \
2042    case sycl::aspect::ASPECT:    \
2043        return #ASPECT;
2044#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
2045#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
2046                auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string
2047                {
2048                    switch (AspectNum)
2049                    {
2050#include <sycl/info/aspects.def>
2051#include <sycl/info/aspects_deprecated.def>
2052                    default:
2053                        return "unknown aspect";
2054                    }
2055                };
2056#undef __SYCL_ASPECT_DEPRECATED_ALIAS
2057#undef __SYCL_ASPECT_DEPRECATED
2058#undef __SYCL_ASPECT
2059                throw std::runtime_error(
2060                    "'" + getAspectNameStr(it) + "' is not supported in '" +
2061                    dev.get_info<sycl::info::device::name>() + "' device");
2062            }
2063            break;
2064        }
2065    }
2066
2067    static inline unsigned int get_current_device_id()
2068    {
2069        return dev_mgr::instance().current_device_id();
2070    }
2071
2072    static inline device_ext &get_current_device()
2073    {
2074        return dev_mgr::instance().current_device();
2075    }
2076
2077    static inline device_ext &get_device(unsigned int id)
2078    {
2079        return dev_mgr::instance().get_device(id);
2080    }
2081
2082    static inline sycl::queue &get_in_order_queue()
2083    {
2084        return dev_mgr::instance().current_device().in_order_queue();
2085    }
2086
2087    static sycl::event
2088    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
2089                memcpy_direction direction,
2090                const std::vector<sycl::event> &dep_events = {})
2091    {
2092        if (!size)
2093            return sycl::event{};
2094        return q.memcpy(to_ptr, from_ptr, size, dep_events);
2095        GGML_UNUSED(direction);
2096    }
2097
2098    // Get actual copy range and make sure it will not exceed range.
2099    static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
2100                                        size_t pitch)
2101    {
2102        return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
2103    }
2104
2105    static inline size_t get_offset(sycl::id<3> id, size_t slice,
2106                                    size_t pitch)
2107    {
2108        return slice * id.get(2) + pitch * id.get(1) + id.get(0);
2109    }
2110
2111    /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
2112    /// and \p from_range to another specified by \p to_ptr and \p to_range.
2113    static inline std::vector<sycl::event>
2114    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
2115                sycl::range<3> to_range, sycl::range<3> from_range,
2116                sycl::id<3> to_id, sycl::id<3> from_id,
2117                sycl::range<3> size, memcpy_direction direction,
2118                const std::vector<sycl::event> &dep_events = {})
2119    {
2120        // RAII for host pointer
2121        class host_buffer
2122        {
2123            void *_buf;
2124            size_t _size;
2125            sycl::queue &_q;
2126            const std::vector<sycl::event> &_deps; // free operation depends
2127
2128        public:
2129            host_buffer(size_t size, sycl::queue &q,
2130                        const std::vector<sycl::event> &deps)
2131                : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
2132            void *get_ptr() const { return _buf; }
2133            size_t get_size() const { return _size; }
2134            ~host_buffer()
2135            {
2136                if (_buf)
2137                {
2138                    _q.submit([&](sycl::handler &cgh)
2139                              {
2140            cgh.depends_on(_deps);
2141            cgh.host_task([buf = _buf] { std::free(buf); }); });
2142                }
2143            }
2144        };
2145        std::vector<sycl::event> event_list;
2146
2147        size_t to_slice = to_range.get(1) * to_range.get(0),
2148               from_slice = from_range.get(1) * from_range.get(0);
2149        unsigned char *to_surface =
2150            (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
2151        const unsigned char *from_surface =
2152            (const unsigned char *)from_ptr +
2153            get_offset(from_id, from_slice, from_range.get(0));
2154
2155        if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
2156        {
2157            return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
2158                                direction, dep_events)};
2159        }
2160        direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
2161        size_t size_slice = size.get(1) * size.get(0);
2162        switch (direction)
2163        {
2164        case host_to_host:
2165            for (size_t z = 0; z < size.get(2); ++z)
2166            {
2167                unsigned char *to_ptr = to_surface;
2168                const unsigned char *from_ptr = from_surface;
2169                if (to_range.get(0) == from_range.get(0) &&
2170                    to_range.get(0) == size.get(0))
2171                {
2172                    event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
2173                                                     direction, dep_events));
2174                }
2175                else
2176                {
2177                    for (size_t y = 0; y < size.get(1); ++y)
2178                    {
2179                        event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
2180                                                         direction, dep_events));
2181                        to_ptr += to_range.get(0);
2182                        from_ptr += from_range.get(0);
2183                    }
2184                }
2185                to_surface += to_slice;
2186                from_surface += from_slice;
2187            }
2188            break;
2189        case host_to_device:
2190        {
2191            host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
2192                            event_list);
2193            std::vector<sycl::event> host_events;
2194            if (to_slice == size_slice)
2195            {
2196                // Copy host data to a temp host buffer with the shape of target.
2197                host_events =
2198                    dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
2199                                sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
2200                                host_to_host, dep_events);
2201            }
2202            else
2203            {
2204                // Copy host data to a temp host buffer with the shape of target.
2205                host_events = dpct_memcpy(
2206                    q, buf.get_ptr(), from_surface, to_range, from_range,
2207                    sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
2208                    // If has padding data, not sure whether it is useless. So fill temp
2209                    // buffer with it.
2210                    std::vector<sycl::event>{
2211                        dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
2212                                    device_to_host, dep_events)});
2213            }
2214            // Copy from temp host buffer to device with only one submit.
2215            event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
2216                                             buf.get_size(), host_to_device,
2217                                             host_events));
2218            break;
2219        }
2220        case device_to_host:
2221        {
2222            host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
2223                            event_list);
2224            // Copy from host temp buffer to host target with reshaping.
2225            event_list = dpct_memcpy(
2226                q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
2227                sycl::id<3>(0, 0, 0), size, host_to_host,
2228                // Copy from device to temp host buffer with only one submit.
2229                std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
2230                                                     buf.get_size(),
2231                                                     device_to_host, dep_events)});
2232            break;
2233        }
2234        case device_to_device:
2235            event_list.push_back(q.submit([&](sycl::handler &cgh)
2236                                          {
2237        cgh.depends_on(dep_events);
2238        cgh.parallel_for<class dpct_memcpy_3d_detail>(
2239            size,
2240            [=](sycl::id<3> id) {
2241                to_surface[get_offset(id, to_slice, to_range.get(0))] =
2242                    from_surface[get_offset(id, from_slice, from_range.get(0))];
2243            }); }));
2244        break;
2245        default:
2246            throw std::runtime_error("dpct_memcpy: invalid direction value");
2247        }
2248        return event_list;
2249    }
2250
2251    /// memcpy 2D/3D matrix specified by pitched_data.
2252    static inline std::vector<sycl::event>
2253    dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
2254                pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
2255                memcpy_direction direction = automatic)
2256    {
2257        return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
2258                           sycl::range<3>(to.get_pitch(), to.get_y(), 1),
2259                           sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
2260                           size, direction);
2261    }
2262
2263    /// memcpy 2D matrix with pitch.
2264    static inline std::vector<sycl::event>
2265    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
2266                size_t to_pitch, size_t from_pitch, size_t x, size_t y,
2267                memcpy_direction direction = automatic)
2268    {
2269        return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
2270                           sycl::range<3>(from_pitch, y, 1),
2271                           sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
2272                           sycl::range<3>(x, y, 1), direction);
2273    }
2274
2275    inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
2276                     int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2277                     library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2278                     library_data_t scaling_type) {
2279        if (scaling_type == library_data_t::real_float &&
2280            c_type == library_data_t::complex_float)
2281        {
2282            scaling_type = library_data_t::complex_float;
2283        }
2284        else if (scaling_type == library_data_t::real_double &&
2285                 c_type == library_data_t::complex_double)
2286        {
2287            scaling_type = library_data_t::complex_double;
2288        }
2289
2290        std::uint64_t key =
2291            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2292        switch (key)
2293        {
2294        case detail::get_type_combination_id(
2295            library_data_t::real_float, library_data_t::real_float,
2296            library_data_t::real_float, library_data_t::real_float):
2297        {
2298            detail::gemm_impl<float, float, float, float>(
2299                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2300            break;
2301        }
2302        case detail::get_type_combination_id(
2303            library_data_t::real_double, library_data_t::real_double,
2304            library_data_t::real_double, library_data_t::real_double):
2305        {
2306            detail::gemm_impl<double, double, double, double>(
2307                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2308            break;
2309        }
2310        case detail::get_type_combination_id(
2311            library_data_t::complex_float, library_data_t::complex_float,
2312            library_data_t::complex_float, library_data_t::complex_float):
2313        {
2314            detail::gemm_impl<std::complex<float>, std::complex<float>,
2315                              std::complex<float>, std::complex<float>>(
2316                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2317            break;
2318        }
2319        case detail::get_type_combination_id(
2320            library_data_t::complex_double, library_data_t::complex_double,
2321            library_data_t::complex_double, library_data_t::complex_double):
2322        {
2323            detail::gemm_impl<std::complex<double>, std::complex<double>,
2324                              std::complex<double>, std::complex<double>>(
2325                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2326            break;
2327        }
2328        case detail::get_type_combination_id(
2329            library_data_t::real_half, library_data_t::real_half,
2330            library_data_t::real_half, library_data_t::real_half):
2331        {
2332            detail::gemm_impl<sycl::half, sycl::half, sycl::half,
2333                              sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,
2334                                          lda, b, ldb, beta, c, ldc);
2335            break;
2336        }
2337#ifdef __INTEL_MKL__
2338        case detail::get_type_combination_id(
2339            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2340            library_data_t::real_float, library_data_t::real_float):
2341        {
2342            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2343                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2344            break;
2345        }
2346        case detail::get_type_combination_id(
2347            library_data_t::real_half, library_data_t::real_half,
2348            library_data_t::real_float, library_data_t::real_float):
2349        {
2350            detail::gemm_impl<sycl::half, sycl::half, float, float>(
2351                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2352            break;
2353        }
2354        case detail::get_type_combination_id(
2355            library_data_t::real_half, library_data_t::real_half,
2356            library_data_t::real_half, library_data_t::real_float):
2357        {
2358            float alpha_value =
2359                dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2360            float beta_value =
2361                dpct::get_value(reinterpret_cast<const float *>(beta), q);
2362            sycl::half alpha_half(alpha_value);
2363            sycl::half beta_half(beta_value);
2364            detail::gemm_impl<sycl::half, sycl::half, sycl::half,
2365                              sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,
2366                                          a, lda, b, ldb, &beta_half, c, ldc);
2367            break;
2368        }
2369        case detail::get_type_combination_id(
2370            library_data_t::real_int8, library_data_t::real_int8,
2371            library_data_t::real_float, library_data_t::real_float):
2372        {
2373            detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
2374                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2375            break;
2376        }
2377        case detail::get_type_combination_id(
2378            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2379            library_data_t::real_bfloat16, library_data_t::real_float):
2380        {
2381            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2382                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2383            break;
2384        }
2385        case detail::get_type_combination_id(
2386            library_data_t::real_int8, library_data_t::real_int8,
2387            library_data_t::real_int32, library_data_t::real_int32):
2388        {
2389            float alpha_float =
2390                dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2391            float beta_float =
2392                dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2393            detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2394                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
2395            break;
2396        }
2397#endif // __INTEL_MKL__
2398        default:
2399            throw std::runtime_error("the combination of data type is unsupported");
2400        }
2401    }  // gemm()
2402
2403    /// Computes a batch of matrix-matrix product with general matrices.
2404    /// \param [in] q The queue where the routine should be executed.
2405    /// \param [in] a_trans Specifies the operation applied to A.
2406    /// \param [in] b_trans Specifies the operation applied to B.
2407    /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
2408    /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
2409    /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
2410    /// \param [in] alpha Scaling factor for the matrix-matrix product.
2411    /// \param [in] a Input matrix A.
2412    /// \param [in] a_type Data type of the matrix A.
2413    /// \param [in] lda Leading dimension of A.
2414    /// \param [in] b Input matrix B.
2415    /// \param [in] b_type Data type of the matrix B.
2416    /// \param [in] ldb Leading dimension of B.
2417    /// \param [in] beta Scaling factor for matrix C.
2418    /// \param [in, out] c Input/Output matrix C.
2419    /// \param [in] c_type Data type of the matrix C.
2420    /// \param [in] ldc Leading dimension of C.
2421    /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2422    /// \param [in] scaling_type Data type of the scaling factors.
2423    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2424                           int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2425                           const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2426                           library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2427                           matrix_info_t<float> * matrix_info) {
2428        std::uint64_t key =
2429            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2430        switch (key)
2431        {
2432        case detail::get_type_combination_id(
2433            library_data_t::real_float, library_data_t::real_float,
2434            library_data_t::real_float, library_data_t::real_float):
2435        {
2436            detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437                                                                beta, c, ldc, batch_size, matrix_info);
2438            break;
2439        }
2440        case detail::get_type_combination_id(
2441            library_data_t::real_double, library_data_t::real_double,
2442            library_data_t::real_double, library_data_t::real_double):
2443        {
2444            detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2445                                                                    beta, c, ldc, batch_size, matrix_info);
2446            break;
2447        }
2448        case detail::get_type_combination_id(
2449            library_data_t::real_half, library_data_t::real_half,
2450            library_data_t::real_half, library_data_t::real_half):
2451        {
2452            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2453                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2454            break;
2455        }
2456#ifdef __INTEL_MKL__
2457        case detail::get_type_combination_id(
2458            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2459            library_data_t::real_bfloat16, library_data_t::real_float):
2460        {
2461            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2462                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2463            break;
2464        }
2465        case detail::get_type_combination_id(
2466            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2467            library_data_t::real_float, library_data_t::real_float):
2468        {
2469            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2470                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2471            break;
2472        }
2473#endif
2474        case detail::get_type_combination_id(
2475            library_data_t::real_int8, library_data_t::real_int8,
2476            library_data_t::real_int32, library_data_t::real_int32):
2477        {
2478            float alpha_float =
2479                dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2480            float beta_float =
2481                dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2482            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2483                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2484                matrix_info);
2485            break;
2486        }
2487        case detail::get_type_combination_id(
2488            library_data_t::real_int8, library_data_t::real_int8,
2489            library_data_t::real_float, library_data_t::real_float):
2490        {
2491            detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2492                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2493            break;
2494        }
2495        case detail::get_type_combination_id(
2496            library_data_t::real_half, library_data_t::real_half,
2497            library_data_t::real_float, library_data_t::real_float):
2498        {
2499            detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2500                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2501            break;
2502        }
2503        case detail::get_type_combination_id(
2504            library_data_t::real_half, library_data_t::real_half,
2505            library_data_t::real_half, library_data_t::real_float):
2506        {
2507            float alpha_value =
2508                dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2509            float beta_value =
2510                dpct::get_value(reinterpret_cast<const float *>(beta), q);
2511            sycl::half alpha_half(alpha_value);
2512            sycl::half beta_half(beta_value);
2513            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2514                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
2515            break;
2516        }
2517        default:
2518            throw std::runtime_error("the combination of data type is unsupported");
2519        }
2520    }
2521
2522    /// Computes a batch of matrix-matrix product with general matrices.
2523    /// \param [in] q The queue where the routine should be executed.
2524    /// \param [in] a_trans Specifies the operation applied to A.
2525    /// \param [in] b_trans Specifies the operation applied to B.
2526    /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
2527    /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
2528    /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
2529    /// \param [in] alpha Scaling factor for the matrix-matrix product.
2530    /// \param [in] a Input matrix A.
2531    /// \param [in] a_type Data type of the matrix A.
2532    /// \param [in] lda Leading dimension of A.
2533    /// \param [in] stride_a Stride between the different A matrices.
2534    /// \param [in] b Input matrix B.
2535    /// \param [in] b_type Data type of the matrix B.
2536    /// \param [in] ldb Leading dimension of B.
2537    /// \param [in] stride_b Stride between the different B matrices.
2538    /// \param [in] beta Scaling factor for matrix C.
2539    /// \param [in, out] c Input/Output matrix C.
2540    /// \param [in] c_type Data type of the matrix C.
2541    /// \param [in] ldc Leading dimension of C.
2542    /// \param [in] stride_c Stride between the different C matrices.
2543    /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2544    /// \param [in] scaling_type Data type of the scaling factors.
2545    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2546                           int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2547                           long long int stride_a, const void * b, library_data_t b_type, int ldb,
2548                           long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
2549                           long long int stride_c, int batch_size, library_data_t scaling_type) {
2550        if (scaling_type == library_data_t::real_float &&
2551            c_type == library_data_t::complex_float)
2552        {
2553            scaling_type = library_data_t::complex_float;
2554        }
2555        else if (scaling_type == library_data_t::real_double &&
2556                 c_type == library_data_t::complex_double)
2557        {
2558            scaling_type = library_data_t::complex_double;
2559        }
2560
2561        std::uint64_t key =
2562            detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2563        switch (key)
2564        {
2565        case detail::get_type_combination_id(
2566            library_data_t::real_float, library_data_t::real_float,
2567            library_data_t::real_float, library_data_t::real_float):
2568        {
2569            detail::gemm_batch_impl<float, float, float, float>(
2570                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2571                beta, c, ldc, stride_c, batch_size);
2572            break;
2573        }
2574        case detail::get_type_combination_id(
2575            library_data_t::real_double, library_data_t::real_double,
2576            library_data_t::real_double, library_data_t::real_double):
2577        {
2578            detail::gemm_batch_impl<double, double, double, double>(
2579                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2580                beta, c, ldc, stride_c, batch_size);
2581            break;
2582        }
2583        case detail::get_type_combination_id(
2584            library_data_t::complex_float, library_data_t::complex_float,
2585            library_data_t::complex_float, library_data_t::complex_float):
2586        {
2587            detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2588                                    std::complex<float>, std::complex<float>>(
2589                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2590                beta, c, ldc, stride_c, batch_size);
2591            break;
2592        }
2593        case detail::get_type_combination_id(
2594            library_data_t::complex_double, library_data_t::complex_double,
2595            library_data_t::complex_double, library_data_t::complex_double):
2596        {
2597            detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2598                                    std::complex<double>, std::complex<double>>(
2599                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2600                beta, c, ldc, stride_c, batch_size);
2601            break;
2602        }
2603        case detail::get_type_combination_id(
2604            library_data_t::real_half, library_data_t::real_half,
2605            library_data_t::real_half, library_data_t::real_half):
2606        {
2607            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2608                                    sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2609                                                a, lda, stride_a, b, ldb, stride_b,
2610                                                beta, c, ldc, stride_c, batch_size);
2611            break;
2612        }
2613#ifdef __INTEL_MKL__
2614        case detail::get_type_combination_id(
2615            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2616            library_data_t::real_bfloat16, library_data_t::real_float):
2617        {
2618            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2619                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2620                batch_size);
2621            break;
2622        }
2623        case detail::get_type_combination_id(
2624            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2625            library_data_t::real_float, library_data_t::real_float):
2626        {
2627            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2628                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2629                batch_size);
2630            break;
2631        }
2632#endif
2633        case detail::get_type_combination_id(
2634            library_data_t::real_int8, library_data_t::real_int8,
2635            library_data_t::real_int32, library_data_t::real_int32):
2636        {
2637            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2638                                    std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
2639                                                  a, lda, stride_a, b, ldb, stride_b,
2640                                                  beta, c, ldc, stride_c, batch_size);
2641            break;
2642        }
2643        case detail::get_type_combination_id(
2644            library_data_t::real_int8, library_data_t::real_int8,
2645            library_data_t::real_float, library_data_t::real_float):
2646        {
2647            detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2648                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2649                beta, c, ldc, stride_c, batch_size);
2650            break;
2651        }
2652        case detail::get_type_combination_id(
2653            library_data_t::real_half, library_data_t::real_half,
2654            library_data_t::real_float, library_data_t::real_float):
2655        {
2656            detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2657                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2658                beta, c, ldc, stride_c, batch_size);
2659            break;
2660        }
2661        case detail::get_type_combination_id(
2662            library_data_t::real_half, library_data_t::real_half,
2663            library_data_t::real_half, library_data_t::real_float):
2664        {
2665            float alpha_value =
2666                dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2667            float beta_value =
2668                dpct::get_value(reinterpret_cast<const float *>(beta), q);
2669            sycl::half alpha_half(alpha_value);
2670            sycl::half beta_half(beta_value);
2671            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2672                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
2673                &beta_half, c, ldc, stride_c, batch_size);
2674            break;
2675        }
2676        default:
2677            throw std::runtime_error("the combination of data type is unsupported");
2678        }
2679    }
2680
2681    static inline void
2682    async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
2683                      size_t from_pitch, size_t x, size_t y,
2684                      memcpy_direction direction = automatic,
2685                      sycl::queue &q = get_default_queue())
2686    {
2687        detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
2688                            direction);
2689    }
2690
    // Two error aliases over detail::generic_error_type. Both carry an int, but
    // each uses its own tag struct, so err0 and err1 are distinct types.
    using err0 = detail::generic_error_type<struct err0_tag, int>;
    using err1 = detail::generic_error_type<struct err1_tag, int>;
2693
2694    static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {
2695        detail::dpct_free(ptr, q);
2696    }
2697
2698    /// dpct accessor used as device function parameter.
2699    template <class T, memory_region Memory, size_t Dimension> class accessor;
2700    template <class T, memory_region Memory> class accessor<T, Memory, 3> {
2701    public:
2702        using memory_t = detail::memory_traits<Memory, T>;
2703        using element_t = typename memory_t::element_t;
2704        using pointer_t = typename memory_t::pointer_t;
2705        using accessor_t = typename memory_t::template accessor_t<3>;
2706        accessor(pointer_t data, const sycl::range<3> &in_range)
2707            : _data(data), _range(in_range) {}
2708        template <memory_region M = Memory>
2709        accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
2710            : accessor(acc, acc.get_range()) {}
2711        accessor(const accessor_t &acc, const sycl::range<3> &in_range)
2712            : accessor(acc.get_pointer(), in_range) {}
2713        accessor<T, Memory, 2> operator[](size_t index) const {
2714            sycl::range<2> sub(_range.get(1), _range.get(2));
2715            return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
2716        }
2717
2718        pointer_t get_ptr() const { return _data; }
2719
2720    private:
2721        pointer_t _data;
2722        sycl::range<3> _range;
2723    };
2724    template <class T, memory_region Memory> class accessor<T, Memory, 2> {
2725    public:
2726        using memory_t = detail::memory_traits<Memory, T>;
2727        using element_t = typename memory_t::element_t;
2728        using pointer_t = typename memory_t::pointer_t;
2729        using accessor_t = typename memory_t::template accessor_t<2>;
2730        accessor(pointer_t data, const sycl::range<2> &in_range)
2731            : _data(data), _range(in_range) {}
2732        template <memory_region M = Memory>
2733        accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
2734            : accessor(acc, acc.get_range()) {}
2735        accessor(const accessor_t &acc, const sycl::range<2> &in_range)
2736            : accessor(acc.get_pointer(), in_range) {}
2737
2738        pointer_t operator[](size_t index) const {
2739            return _data + _range.get(1) * index;
2740        }
2741
2742        pointer_t get_ptr() const { return _data; }
2743
2744    private:
2745        pointer_t _data;
2746        sycl::range<2> _range;
2747    };
2748
2749    namespace detail {
2750        /// Device variable with address space of shared, global or constant.
2751        template <class T, memory_region Memory, size_t Dimension> class device_memory {
2752        public:
2753            using accessor_t =
2754                typename detail::memory_traits<Memory,
2755                                            T>::template accessor_t<Dimension>;
2756            using value_t = typename detail::memory_traits<Memory, T>::value_t;
2757            using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
2758
2759            device_memory() : device_memory(sycl::range<Dimension>(1)) {}
2760
2761            /// Constructor of 1-D array with initializer list
2762            device_memory(const sycl::range<Dimension> &in_range,
2763                        std::initializer_list<value_t> &&init_list)
2764                : device_memory(in_range) {
2765                assert(init_list.size() <= in_range.size());
2766                _host_ptr = (value_t *)std::malloc(_size);
2767                std::memset(_host_ptr, 0, _size);
2768                std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
2769            }
2770
2771            /// Constructor of 2-D array with initializer list
2772            template <size_t D = Dimension>
2773            device_memory(
2774                const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
2775                std::initializer_list<std::initializer_list<value_t>> &&init_list)
2776                : device_memory(in_range) {
2777                assert(init_list.size() <= in_range[0]);
2778                _host_ptr = (value_t *)std::malloc(_size);
2779                std::memset(_host_ptr, 0, _size);
2780                auto tmp_data = _host_ptr;
2781                for (auto sub_list : init_list) {
2782                    assert(sub_list.size() <= in_range[1]);
2783                    std::memcpy(tmp_data, sub_list.begin(),
2784                                sub_list.size() * sizeof(T));
2785                    tmp_data += in_range[1];
2786                }
2787            }
2788
2789            /// Constructor with range
2790            device_memory(const sycl::range<Dimension> &range_in)
2791                : _size(range_in.size() * sizeof(T)), _range(range_in),
2792                _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
2793                static_assert(
2794                    (Memory == global) || (Memory == constant) || (Memory == shared),
2795                    "device memory region should be global, constant or shared");
2796                // Make sure that singleton class mem_mgr and dev_mgr will destruct
2797                // later than this.
2798                detail::mem_mgr::instance();
2799                dev_mgr::instance();
2800            }
2801
2802            /// Constructor with range
2803            template <class... Args>
2804            device_memory(Args... Arguments)
2805                : device_memory(sycl::range<Dimension>(Arguments...)) {}
2806
2807            ~device_memory() {
2808                if (_device_ptr && !_reference)
2809                    dpct::dpct_free(_device_ptr);
2810                if (_host_ptr)
2811                    std::free(_host_ptr);
2812            }
2813
2814            /// Allocate memory with default queue, and init memory if has initial
2815            /// value.
2816            void init() { init(dpct::get_default_queue()); }
2817            /// Allocate memory with specified queue, and init memory if has initial
2818            /// value.
2819            void init(sycl::queue &q) {
2820                if (_device_ptr)
2821                    return;
2822                if (!_size)
2823                    return;
2824                allocate_device(q);
2825                if (_host_ptr)
2826                    detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,
2827                                        host_to_device);
2828            }
2829
2830            /// The variable is assigned to a device pointer.
2831            void assign(value_t *src, size_t size) {
2832                this->~device_memory();
2833                new (this) device_memory(src, size);
2834            }
2835
2836            /// Get memory pointer of the memory object, which is virtual pointer when
2837            /// usm is not used, and device pointer when usm is used.
2838            value_t *get_ptr() { return get_ptr(get_default_queue()); }
2839            /// Get memory pointer of the memory object, which is virtual pointer when
2840            /// usm is not used, and device pointer when usm is used.
2841            value_t *get_ptr(sycl::queue &q) {
2842                init(q);
2843                return _device_ptr;
2844            }
2845
2846            /// Get the device memory object size in bytes.
2847            size_t get_size() { return _size; }
2848
2849            template <size_t D = Dimension>
2850            typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
2851                init();
2852                return _device_ptr[index];
2853            }
2854
2855            /// Get dpct::accessor with dimension info for the device memory object
2856            /// when usm is used and dimension is greater than 1.
2857            template <size_t D = Dimension>
2858            typename std::enable_if<D != 1, dpct_accessor_t>::type
2859            get_access([[maybe_unused]] sycl::handler &cgh) {
2860                return dpct_accessor_t((T *)_device_ptr, _range);
2861            }
2862
2863        private:
2864            device_memory(value_t *memory_ptr, size_t size)
2865                : _size(size), _range(size / sizeof(T)), _reference(true),
2866                _device_ptr(memory_ptr) {}
2867
2868            void allocate_device(sycl::queue &q) {
2869        #ifndef DPCT_USM_LEVEL_NONE
2870                if (Memory == shared) {
2871                    _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
2872                                                                q.get_context());
2873                    return;
2874                }
2875        #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
2876                if (Memory == constant) {
2877                    _device_ptr = (value_t *)sycl::malloc_device(
2878                        _size, q.get_device(), q.get_context(),
2879                        sycl::ext::oneapi::property::usm::device_read_only());
2880                    return;
2881                }
2882        #endif
2883        #endif
2884                _device_ptr = (value_t *)detail::dpct_malloc(_size, q);
2885            }
2886
            // Size of the memory object in bytes (the value returned by get_size()).
            size_t _size;
            // Range of the memory object; the pointer-wrapping constructor builds
            // it from size / sizeof(T).
            sycl::range<Dimension> _range;
            // Set to true by the pointer-wrapping constructor that assign() uses.
            bool _reference;
            // Host-side data; when non-null, init() copies it to the device
            // allocation after allocating.
            value_t *_host_ptr;
            // Pointer returned by get_ptr(); set by allocate_device() or supplied
            // directly through assign().
            value_t *_device_ptr;
2892        };
2893        template <class T, memory_region Memory>
2894        class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
2895        public:
2896            using base = device_memory<T, Memory, 1>;
2897            using value_t = typename base::value_t;
2898            using accessor_t =
2899                typename detail::memory_traits<Memory, T>::template accessor_t<0>;
2900
2901            /// Constructor with initial value.
2902            device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
2903
2904            /// Default constructor
2905            device_memory() : base(1) {}
2906        };
2907        } // namespace detail
2908
2909    template <class T, size_t Dimension>
2910    using global_memory = detail::device_memory<T, global, Dimension>;
2911    template <class T, size_t Dimension>
2912    using constant_memory = detail::device_memory<T, constant, Dimension>;
2913    template <class T, size_t Dimension>
2914    using shared_memory = detail::device_memory<T, shared, Dimension>;
2915
2916
2917    template <typename T,
2918            sycl::access::address_space addressSpace =
2919                sycl::access::address_space::global_space,
2920            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2921            sycl::memory_scope memoryScope = sycl::memory_scope::device>
2922    inline T atomic_fetch_add(T *addr, T operand) {
2923    auto atm =
2924        sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
2925    return atm.fetch_add(operand);
2926    }
2927
2928    template <sycl::access::address_space addressSpace =
2929                sycl::access::address_space::global_space,
2930            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2931            sycl::memory_scope memoryScope = sycl::memory_scope::device,
2932            typename T1, typename T2>
2933    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
2934    auto atm =
2935        sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
2936    return atm.fetch_add(operand);
2937    }
2938
2939    template <typename T, sycl::access::address_space addressSpace =
2940                            sycl::access::address_space::global_space>
2941    inline T atomic_fetch_add(T *addr, T operand,
2942                            sycl::memory_order memoryOrder) {
2943    switch (memoryOrder) {
2944        case sycl::memory_order::relaxed:
2945            return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
2946                                    sycl::memory_scope::device>(addr, operand);
2947        case sycl::memory_order::acq_rel:
2948            return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
2949                                    sycl::memory_scope::device>(addr, operand);
2950        case sycl::memory_order::seq_cst:
2951            return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
2952                                    sycl::memory_scope::device>(addr, operand);
2953        default:
2954            assert(false && "Invalid memory_order for atomics. Valid memory_order for "
2955                            "atomics are: sycl::memory_order::relaxed, "
2956                            "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
2957        }
2958    }
2959
2960    template <sycl::access::address_space addressSpace =
2961                sycl::access::address_space::global_space,
2962            typename T1, typename T2>
2963    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
2964                            sycl::memory_order memoryOrder) {
2965    atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2966    }
2967
2968    inline unsigned int byte_level_permute(
2969        unsigned int a, unsigned int b, unsigned int s) {
2970      unsigned int ret;
2971      ret = ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) |
2972            (((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff)
2973             << 8) |
2974            (((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff)
2975             << 16) |
2976            (((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff)
2977             << 24);
2978      return ret;
2979    }
2980
2981    inline uint32_t byte_level_permute_custom(
2982        uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
2983      constexpr uint16_t lookup[6][4] = {
2984          {0x3210, 0x4321, 0x5432, 0x6543},  // Forward 4-byte extract
2985          {0x5670, 0x6701, 0x7012, 0x0123},  // Backward 4-byte extract
2986          {0x0000, 0x1111, 0x2222, 0x3333},  // Replicate 8-bit values
2987          {0x3210, 0x3211, 0x3222, 0x3333},  // Edge clamp left
2988          {0x0000, 0x1110, 0x2210, 0x3210},  // Edge clamp right
2989          {0x1010, 0x3232, 0x1010, 0x3232}   // Replicate 16-bit values
2990      };
2991
2992      if (mode >= 1 && mode <= 6) {
2993        return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
2994      } else if (!mode) {
2995        return byte_level_permute(low32, high32, sel);
2996      }
2997      return 0;
2998    }
2999
3000} // COPY from DPCT head files
3001
3002#endif // GGML_SYCL_DPCT_HELPER_HPP