1//
  2// MIT license
  3// Copyright (C) 2024 Intel Corporation
  4// SPDX-License-Identifier: MIT
  5//
  6
  7//
  8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9// See https://llvm.org/LICENSE.txt for license information.
 10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 11//
 12
 13#ifndef GGML_SYCL_COMMON_HPP
 14#define GGML_SYCL_COMMON_HPP
 15
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
 20
 21#include "dpct/helper.hpp"
 22#include "ggml-sycl.h"
 23#include "presets.hpp"
 24#include "sycl_hw.hpp"
 25
 26
 27#if GGML_SYCL_DNNL
 28#include "dnnl.hpp"
 29#include "dnnl_sycl.hpp"
 30#endif
 31
 32#define GGML_COMMON_DECL_SYCL
 33#define GGML_COMMON_IMPL_SYCL
 34/* suppress warning spam */
 35#pragma clang diagnostic push
 36#pragma clang diagnostic ignored "-Wnested-anon-types"
 37#include "ggml-common.h"
 38#pragma clang diagnostic pop
 39#include "ggml-impl.h"
 40
// Host-memory allocation helpers; declared here, defined in another translation unit.
void* ggml_sycl_host_malloc(size_t size);
void ggml_sycl_host_free(void* ptr);


// Backend-wide runtime flags (declared here, defined elsewhere).
// g_ggml_sycl_debug gates GGML_SYCL_DEBUG and the debug_* string helpers below.
// NOTE(review): the other two are presumably set from the environment at backend
// init — confirm.
extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;
extern int g_ggml_sycl_prioritize_dmmv;
 48
// Branch-prediction hints. Under clang (when __builtin_expect is available) they
// forward to __builtin_expect; otherwise they expand to the parenthesised expression.
// The truth value of `expr` is unchanged either way.
// NOTE(review): on a non-clang compiler that does not know __has_builtin, the
// `__has_builtin(...)` term may fail to preprocess even though `defined(__clang__)`
// is false; consider testing `defined(__has_builtin)` first.
#if defined(__clang__) && __has_builtin(__builtin_expect)
// Hint the optimizer to pipeline the more likely following instruction in branches
#    define LIKELY(expr)   __builtin_expect(expr, true)
#    define UNLIKELY(expr) __builtin_expect(expr, false)
#else
#    define LIKELY(expr)   (expr)
#    define UNLIKELY(expr) (expr)
#endif

// printf-style logging to stderr, active only when g_ggml_sycl_debug is non-zero.
// The do { } while (0) wrapper makes the macro behave as a single statement.
#define GGML_SYCL_DEBUG(...)              \
    do {                                  \
        if (UNLIKELY(g_ggml_sycl_debug))  \
            fprintf(stderr, __VA_ARGS__); \
    } while (0)
 63
 64#define CHECK_TRY_ERROR(expr)                                            \
 65  [&]() {                                                                \
 66    try {                                                                \
 67      expr;                                                              \
 68      return dpct::success;                                              \
 69    } catch (std::exception const& e) {                                  \
 70      std::cerr << e.what() << "\nException caught at file:" << __FILE__ \
 71                << ", line:" << __LINE__ << ", func:" << __func__        \
 72                << std::endl;                                            \
 73      return dpct::default_error;                                        \
 74    }                                                                    \
 75  }()
 76
 77
// Architecture/version placeholders; these appear to be carried over from the DPCT
// migration of the CUDA backend.
// NOTE(review): names that begin with a double underscore (__SYCL_ARCH__) are
// reserved for the implementation; the VER_* values look like per-architecture
// tuning thresholds — confirm their use.
#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
#define VER_4VEC 610 // TODO: tune for hardware.
#define VER_GEN9 700 // TODO: tune for hardware.
#define VER_GEN12 1000000 // TODO: tune for hardware.
#define VER_GEN13 (VER_GEN12 + 1030) // TODO: tune for hardware.

#define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardware

// define for XMX in Intel GPU
// TODO: currently, it's not actually used for XMX.
// SYCL_USE_XMX is defined unless the build forces MMQ (GGML_SYCL_FORCE_MMQ).
#if !defined(GGML_SYCL_FORCE_MMQ)
    #define SYCL_USE_XMX
#endif

// max batch size to use MMQ kernels when matrix engines (tensor cores / XMX) are available
#define MMQ_MAX_BATCH_SIZE 32

// dmmv = dequantize_mul_mat_vec
// Both values below are only defaults; a build may override them with -D.
#ifndef GGML_SYCL_DMMV_X
#define GGML_SYCL_DMMV_X 32
#endif
#ifndef GGML_SYCL_MMV_Y
#define GGML_SYCL_MMV_Y 1
#endif
102
103typedef sycl::queue *queue_ptr;
104
105enum ggml_sycl_backend_gpu_mode {
106  SYCL_UNSET_GPU_MODE = -1,
107  SYCL_SINGLE_GPU_MODE = 0,
108  SYCL_MUL_GPU_MODE
109};
110
111static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size");
112
// Terminate the process abnormally (SIGABRT) on an unrecoverable error path.
// The previous implementation stored through a null pointer. That is undefined
// behavior: an optimizer may delete the store or treat the path as unreachable,
// so the intended crash was not guaranteed. std::abort() is well-defined and
// always terminates.
static void crash() {
  std::abort();
}
117
118[[noreturn]] static void ggml_sycl_error(
119    const char* stmt,
120    const char* func,
121    const char* file,
122    const int line,
123    const char* msg) {
124  fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg);
125  fprintf(stderr, "  in function %s at %s:%d\n", func, file, line);
126  GGML_ABORT("SYCL error");
127}
128
// Evaluate `err` exactly once. If the result is non-zero, report the stringified
// expression and its location through ggml_sycl_error(), which aborts.
// Commonly wraps CHECK_TRY_ERROR, e.g. SYCL_CHECK(CHECK_TRY_ERROR(stmt)).
#define SYCL_CHECK(err)                                                                                    \
    do {                                                                                                   \
        auto err_ = (err);                                                                                 \
        if (err_ != 0)                                                                                     \
            ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
    } while (0)

// Optimizer assumption hint. It expands to __builtin_assume(x) when
// DPCT_COMPAT_RT_VERSION >= 11100, and to nothing otherwise, in which case `x` is
// not evaluated at all.
// NOTE(review): 11100 looks like a CUDA 11.1 runtime version inherited from the
// CUDA backend — confirm it is meaningful for SYCL builds.
#if DPCT_COMPAT_RT_VERSION >= 11100
#define GGML_SYCL_ASSUME(x) __builtin_assume(x)
#else
#define GGML_SYCL_ASSUME(x)
#endif // DPCT_COMPAT_RT_VERSION >= 11100
141
142#ifdef GGML_SYCL_F16
143typedef sycl::half dfloat; // dequantize float
144typedef sycl::half2 dfloat2;
145#else
146typedef float dfloat; // dequantize float
147typedef sycl::float2 dfloat2;
148#endif // GGML_SYCL_F16
149
150#define MMVQ_MAX_BATCH_SIZE  8
151
// NOTE: these are `static` variables defined in a header. Every translation unit
// that includes this file gets its own private copy, so a write in one .cpp file is
// not visible in another.
// NOTE(review): confirm per-TU state is intended; an `inline` variable (or an
// extern declaration here plus one definition) would make them shared.
static int g_all_sycl_device_count = -1;
static bool g_ggml_backend_sycl_buffer_type_initialized = false;

static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode =
    SYCL_UNSET_GPU_MODE;

// Scratch buffer bookkeeping; a size of 0 means scratch is disabled.
static void* g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
161
// Report that the binary lacks support for the current GPU architecture by writing
// to `stream_ct1`, then terminate the process with exit code 1.
// NOTE(review): std::exit is a host facility; confirm this is never reached from
// inside a kernel.
[[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) {
  stream_ct1 << "ERROR: ggml-sycl was compiled without support for the "
                "current GPU architecture.\n";
  // __trap();
  std::exit(1);

  // Unreachable. It only references bad_arch so that TUs which never call it do
  // not warn about an unused static function.
  (void)bad_arch; // suppress unused function warning
}
170
171int get_current_device_id();
172
173inline dpct::err0 ggml_sycl_set_device(const int device) try {
174  int current_device_id;
175  SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
176
177  // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d,
178  // current_device_id=%d\n", device, current_device);
179  if (device == current_device_id) {
180    return 0;
181  }
182
183  return CHECK_TRY_ERROR(dpct::select_device(device));
184} catch (sycl::exception const& exc) {
185  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
186            << ", line:" << __LINE__ << std::endl;
187  crash();
188  std::exit(1);
189}
190
191//////////////////////
// Optimization switches. Held per device (sycl_device_info), per backend context
// and per tensor (ggml_tensor_extra_gpu).
struct optimize_feature {
    bool reorder=false;  // NOTE(review): presumably enables a reordered weight layout — confirm
};
195
// Capabilities recorded for a single SYCL device (one entry of
// ggml_sycl_device_info::devices).
struct sycl_device_info {
    int     cc;                 // compute capability
    int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
             // number of compute units on a SYCL device.
    // size_t  smpb;               // max. shared memory per block
    size_t  smpbo;              // max. shared memory per block (with opt-in)
    bool    vmm;                // virtual memory support
    size_t  total_vram;         // NOTE(review): presumably device memory in bytes — confirm
    // sycl_hw_info hw_info;    // device id and arch, currently not used
    optimize_feature opt_feature;
};
207
208
209struct ggml_sycl_device_info {
210    int device_count;
211
212    sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
213
214    std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
215
216    int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
217};
218
// Accessor for the device information table; defined elsewhere.
const ggml_sycl_device_info & ggml_sycl_info();

// Abstract memory pool. Concrete pools are produced by
// ggml_backend_sycl_context::new_pool_for_device / new_pool_for_host.
struct ggml_sycl_pool {
    virtual ~ggml_sycl_pool() = default;

    // Allocate `size` bytes. The size actually reserved is written to *actual_size
    // (NOTE(review): presumably >= size — confirm), and callers hand that value back
    // to free().
    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    // Return `ptr` to the pool; `size` is the *actual_size reported by alloc().
    virtual void free(void * ptr, size_t size) = 0;
};
227
228template<typename T>
229struct ggml_sycl_pool_alloc {
230    ggml_sycl_pool * pool = nullptr;
231    T * ptr = nullptr;
232    size_t actual_size = 0;
233
234    explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) {
235    }
236
237    ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) {
238        alloc(size);
239    }
240
241    ~ggml_sycl_pool_alloc() {
242        if (ptr != nullptr) {
243            pool->free(ptr, actual_size);
244        }
245    }
246
247    T * realloc(size_t size) {
248        GGML_ASSERT(pool != nullptr);
249        if (ptr)
250            pool->free(ptr, actual_size);
251        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
252        return ptr;
253    }
254
255    // size is in number of elements
256    T * alloc(size_t size) {
257        GGML_ASSERT(pool != nullptr);
258        GGML_ASSERT(ptr == nullptr);
259        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
260        return ptr;
261    }
262
263    T * alloc(ggml_sycl_pool & pool, size_t size) {
264        this->pool = &pool;
265        return alloc(size);
266    }
267
268    T * get() {
269        return ptr;
270    }
271
272    ggml_sycl_pool_alloc() = default;
273    ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete;
274    ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete;
275    ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete;
276    ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete;
277};
278
279// backend interface
280
// Backend-specific data attached to a tensor (NOTE(review): presumably stored in
// ggml_tensor::extra — confirm).
struct ggml_tensor_extra_gpu {
  void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split
                                       // tensors
  dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
                        [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
  optimize_feature optimized_feature;   // optimizations applied to this tensor's data
};

// Release the resources held by `extra`; defined elsewhere.
// NOTE(review): `streams` is presumably used to wait on or free per-device
// resources — confirm. It defaults to an empty list.
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
290
// Short alias for the oneAPI experimental extensions (used for SYCL graphs below).
namespace sycl_ex = sycl::ext::oneapi::experimental;

// Per-backend-instance state for one SYCL device:
// - lazily fetched queue pointers,
// - lazily created device/host memory pools,
// - optional oneDNN engine/stream caches and scratchpads,
// - an optional executable SYCL graph.
//
// Member order matters for destruction. Members are destroyed in reverse
// declaration order, so `scratchpad_map` is destroyed before `pools`; its
// ggml_sycl_pool_alloc entries return memory to pool() in their destructors.
// Keep that ordering.
struct ggml_backend_sycl_context {
    int device;                    // device this context primarily targets
    std::string name;              // GGML_SYCL_NAME followed by the device index
    optimize_feature opt_feature;  // copied from ggml_sycl_info().devices[device]

    // Queue-pointer cache filled lazily by stream(); nullptr means "not fetched yet".
    queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };

    explicit ggml_backend_sycl_context(int device) :
        device(device),
        name(GGML_SYCL_NAME + std::to_string(device)) {
        opt_feature = ggml_sycl_info().devices[device].opt_feature;
    }

    // Return (and cache) the queue for (device, stream). Every stream index currently
    // resolves to the device's dpct default queue, so different stream indices share
    // one queue. Indices are not bounds-checked.
    queue_ptr stream(int device, int stream) {
        if (qptrs[device][stream] == nullptr) {
            qptrs[device][stream] = &(dpct::get_device(device).default_queue());
        }
        return qptrs[device][stream];
    }

    // Queue 0 of this context's device.
    queue_ptr stream() {
        return stream(device, 0);
    }

#if GGML_SYCL_DNNL
    // Build a oneDNN engine that shares the device and context of queue `q`.
    dnnl::engine make_engine(sycl::queue* q) {
        // Get the device associated with the queue
        sycl::device dev = q->get_device();
        // Get the context associated with the queue
        sycl::context ctx = q->get_context();
        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
        return eng;
    }

    // Per-queue caches of oneDNN objects so each queue gets one engine and one stream.
    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
    // oneDNN stream for the queue at (device, _stream).
    dnnl::stream stream_dnnl(int device, int _stream) {
        auto q = stream(device, _stream);
        return stream_dnnl(q);
    }
    // Cached oneDNN engine for `qptr`, created on first use.
    dnnl::engine engine_dnnl(sycl::queue* qptr) {
        auto it = engine_map.find(qptr);
        if (it == engine_map.end()) {
            auto eng = make_engine(qptr);
            engine_map[qptr] = eng;
            return eng;
        }
        else
        {
            return it->second;
        }
    }
    // Cached oneDNN stream wrapping `qptr`, created on first use from engine_dnnl(qptr).
    dnnl::stream stream_dnnl(sycl::queue* qptr) {
        auto it = stream_map.find(qptr);
        if (it == stream_map.end()) {
            auto eng = engine_dnnl(qptr);
            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
            stream_map[qptr] = stream;
            return stream;
        }
        else
        {
            return it->second;
        }
    }
    // oneDNN stream for queue 0 of this context's device.
    dnnl::stream stream_dnnl() {
        return stream_dnnl(device, 0);
    }
    // Return a dnnl::memory over a per-queue scratchpad buffer large enough for
    // scratchpad_md.get_size().
    // - The buffer comes from this->pool(), i.e. the pool of `device`, whichever
    //   device `q` belongs to (NOTE(review): confirm that is intended).
    // - The buffer only grows: it is reallocated, without preserving contents, when a
    //   larger size is requested.
    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
                                    const dnnl::engine & eng, const queue_ptr q) {
        ggml_sycl_pool_alloc<uint8_t> * pool;
        auto it = scratchpad_map.find(q);
        if (it == scratchpad_map.end()) {
            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
            pool = scratchpad_map[q].get();
        } else {
            pool = it->second.get();
        }

        size_t scratchpad_size = scratchpad_md.get_size();
        if (scratchpad_size > pool->actual_size) {
            pool->realloc(scratchpad_size);
        }
        void * mem_ptr = pool->get();
        return dnnl::memory(scratchpad_md, eng, mem_ptr);
    }
#endif

    // pool
    // Lazily created device pools, indexed by device.
    std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
    // oneDNN scratchpads keyed by queue. Declared after `pools` so it is destroyed first.
    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;

    // Lazily created host pools, indexed by device.
    std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];

    // Pool factories (defined elsewhere).
    static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);

    static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);

    // Device pool for `device`, created on first use and bound to its stream 0.
    ggml_sycl_pool & pool(int device) {
        if (pools[device] == nullptr) {
            pools[device] = new_pool_for_device(stream(device,0), device);
        }
        return *pools[device];
    }

    ggml_sycl_pool & pool() {
        return pool(device);
    }

#ifdef GGML_SYCL_GRAPH
    // Executable SYCL command graph, when graph support is compiled in.
    std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
#endif

    // Host pool for `device`, created on first use and bound to its stream 0.
    ggml_sycl_pool & host_pool(int device) {
        if (host_pools[device] == nullptr) {
            host_pools[device] = new_pool_for_host(stream(device, 0), device);
        }
        return *host_pools[device];
    }

    ggml_sycl_pool & host_pool() { return host_pool(device); }
};
414
415// common device functions
416
417static __dpct_inline__ float warp_reduce_sum(float x,
418    const sycl::nd_item<3>& item_ct1) {
419#pragma unroll
420    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
421        x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
422    }
423    return x;
424}
425
426static __dpct_inline__ sycl::float2
427warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
428#pragma unroll
429    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
430        a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
431            mask);
432        a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
433            mask);
434    }
435    return a;
436}
437
// Sum of `x` over the calling work-item's entire sub-group via
// sycl::reduce_over_group; every work-item in the sub-group receives the result.
// NOTE(review): unlike the float/float2/half2 overloads, the `width` template
// parameter is ignored here and the reduction always spans the whole sub-group.
// A caller passing width < sub-group size gets a different result than the
// xor-shuffle versions would give; confirm no caller relies on a partial width.
template <int width = WARP_SIZE>
static __dpct_inline__ int warp_reduce_sum(int x) {
  return sycl::reduce_over_group(
      sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
}
443
444template <int width = WARP_SIZE>
445static __dpct_inline__ float warp_reduce_sum(float x) {
446#pragma unroll
447  for (int offset = width / 2; offset > 0; offset >>= 1) {
448    x += dpct::permute_sub_group_by_xor(
449        sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
450  }
451  return x;
452}
453
454template <int width = WARP_SIZE>
455static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
456#pragma unroll
457  for (int offset = width / 2; offset > 0; offset >>= 1) {
458    a.x() += dpct::permute_sub_group_by_xor(
459        sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
460        width);
461    a.y() += dpct::permute_sub_group_by_xor(
462        sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
463        width);
464  }
465  return a;
466}
467
468template <int width = WARP_SIZE>
469static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
470#pragma unroll
471  for (int offset = width / 2; offset > 0; offset >>= 1) {
472    a = a + dpct::permute_sub_group_by_xor(
473                sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
474                width);
475  }
476  return a;
477}
478
// Compile-time sub-group ("warp") size the kernels are written for.
// It currently always returns WARP_SIZE.
static constexpr int ggml_sycl_get_physical_warp_size() {
  // todo: for old iGPU + dGPU case, need to be changed.
  return WARP_SIZE;
}
483
484template <int width = WARP_SIZE>
485static __dpct_inline__ float warp_reduce_max(float x) {
486#pragma unroll
487  for (int offset = width / 2; offset > 0; offset >>= 1) {
488    x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
489                          sycl::ext::oneapi::this_work_item::get_sub_group(), x,
490                          offset, width));
491  }
492  return x;
493}
494
495static __dpct_inline__ float warp_reduce_max(float x,
496    const sycl::nd_item<3>& item_ct1) {
497#pragma unroll
498    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
499        x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
500            item_ct1.get_sub_group(), x, mask));
501    }
502    return x;
503}
504
505/* Helper for Computing the linear offset of a ggml_tensor given
506per-dimension sizes, strides, and indices */
507template<int N>
508__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
509    size_t offset = 0;
510#pragma unroll
511    for (int i = 0; i < N; i++) {
512        auto index_i = indices[i];
513        offset += strides[i] * index_i;
514    }
515    return offset;
516}
517
// Helper for vec loading aligned data
// Reinterpret `aligned_ptr` as a sycl::vec<Tp, n> and load it in a single access.
// The caller must guarantee that aligned_ptr is suitably aligned for
// sycl::vec<Tp, n> and points to at least n readable elements; nothing is checked.
// NOTE(review): this is pointer type-punning through reinterpret_cast.
template <typename Tp, int n>
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
    return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
}
523
// Helper for accessing pointers with no warnings
// Return the raw Tp* behind a local_accessor, obtained through an undecorated
// multi_ptr. The accessor is taken by value.
// NOTE(review): relies on SYCL accessors having reference semantics, so the copy
// refers to the same local allocation.
template <typename Tp, int dim>
static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
    return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
}
529
// Defined elsewhere. NOTE(review): by its name it appears to cap a launch range
// derived from accumulate_block_num and block_size — confirm.
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
531
// Integer division rounded up: ceil(m / n). `n` must be non-zero.
// Written as quotient + (remainder != 0) so it cannot overflow. The classic
// (m + n - 1) / n wraps for m near SIZE_MAX; for example ceil_div(SIZE_MAX, 2)
// used to return 0.
constexpr size_t ceil_div(const size_t m, const size_t n) {
    return m / n + (m % n != 0 ? 1 : 0);
}
535
// Defined elsewhere. By its name, reports whether `dev` has XMX matrix units.
bool gpu_has_xmx(sycl::device &dev);
537
538template <int N, class T> std::string debug_get_array_str(const std::string & prefix, const T array[N]) {
539    if (LIKELY(!g_ggml_sycl_debug)) {
540        return "";
541    }
542    std::stringstream ss;
543    ss << prefix << "=[";
544    for (std::size_t i = 0; i < N - 1; ++i) {
545        ss << array[i] << ", ";
546    }
547    if constexpr (N > 0) {
548        ss << array[N - 1];
549    }
550    ss << "]";
551    return ss.str();
552}
553
554inline std::string debug_get_tensor_str(const std::string &prefix,
555        const ggml_tensor *tensor, const std::string &suffix = "") {
556    std::stringstream ss;
557    if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); }
558    ss << prefix.c_str() << "=";
559    if (tensor) {
560        ss << "'" << tensor->name << "':type=" << ggml_type_name(tensor->type);
561        ss << debug_get_array_str<GGML_MAX_DIMS>(";ne", tensor->ne);
562        ss << debug_get_array_str<GGML_MAX_DIMS>(";nb", tensor->nb);
563
564        if (!ggml_is_contiguous(tensor)) { ss << ";strided"; }
565        if (ggml_is_permuted(tensor)) { ss << ";permuted"; }
566    } else {
567        ss << "nullptr";
568    }
569    ss << suffix;
570    return ss.str();
571}
572
573// Use scope_op_debug_print to log operations coming from running a model
574struct scope_op_debug_print {
575    // Use string_views to avoid the cost of creating a string and concatenating them
576    // string_views must be alive for as long as the object is alive
577    // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible
578    scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst,
579                         std::size_t num_src, const std::string_view & suffix = "") :
580        func(func),
581        func_suffix(func_suffix) {
582        if (LIKELY(!g_ggml_sycl_debug)) {
583            return;
584        }
585        GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data());
586        GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" dst", dst).c_str());
587        if (dst) {
588            for (std::size_t i = 0; i < num_src; ++i) {
589                GGML_SYCL_DEBUG("%s", debug_get_tensor_str("\tsrc" + std::to_string(i), dst->src[i]).c_str());
590            }
591        }
592        GGML_SYCL_DEBUG("%s\n", suffix.data());
593    }
594
595    scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src,
596                         const std::string_view & suffix = "") :
597        scope_op_debug_print(func, "", dst, num_src, suffix) {}
598
599    ~scope_op_debug_print() { GGML_SYCL_DEBUG("[SYCL][OP] call %s%s done\n", func.data(), func_suffix.data()); }
600
601  private:
602    std::string_view func;
603    std::string_view func_suffix;
604};
605
606static __dpct_inline__ float get_alibi_slope(const float    max_bias,
607                                             const uint32_t h,
608                                             const uint32_t n_head_log2,
609                                             const float    m0,
610                                             const float    m1) {
611    if (max_bias <= 0.0f) {
612        return 1.0f;
613    }
614    const float base = h < n_head_log2 ? m0 : m1;
615    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
616
617    return dpct::pow(base, exph);
618}
619
// Precompute the constants that let fastdiv()/fast_div_modulo() divide by the
// runtime value `d` with a multiply-high, an add and a shift.
// Returns (mp, L, d) packed into a sycl::uint3:
//   x = mp = floor(2^32 * (2^L - d) / d) + 1   (magic multiplier)
//   y = L  = smallest shift with 2^L >= d      (i.e. ceil(log2(d)))
//   z = d                                       (needed for the remainder)
// `d` must be non-zero (asserted).
// NOTE(review): for d > 2^31 the loop yields L == 32, and fastdiv() would then shift
// a 32-bit value by 32, which is undefined. Confirm callers keep d <= 2^31.
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
    GGML_ASSERT(d != 0);

    // Smallest L in [0, 32] with 2^L >= d. The `L < 32` test comes first so that
    // 1u << 32 is never evaluated.
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }

    // Computed in 64 bits: 2^32 * (2^L - d) is below 2^64 because 2^L - d < 2^32.
    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    return sycl::uint3(mp, L, d);
}
631
632
633static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
634    const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
635    return (hi + n) >> fastdiv_values.y();
636}
637
638
639static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
640    const uint32_t div_val = fastdiv(n, fastdiv_values);
641    const uint32_t mod_val = n - div_val * fastdiv_values.z();
642    return sycl::uint2(div_val, mod_val);
643}
644
// Thin wrapper over dpct::dp4a. NOTE(review): presumably treats `a` and `b` as four
// packed 8-bit lanes and adds the sum of their lane-wise products to `c`; see
// dpct::dp4a for the exact signedness rules.
static __dpct_inline__ int ggml_sycl_dp4a(const int a, const int b, int c) {
    return dpct::dp4a(a, b, c);
}
648
649static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
650    uint32_t bits;
651    if (x == 0) {
652        bits = 0x00400000;
653    } else {
654        bits = (uint32_t) x << 23;
655    }
656
657    float result;
658    memcpy(&result, &bits, sizeof(float));
659    return result;
660}
661
662
663#endif // GGML_SYCL_COMMON_HPP