//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#include <algorithm>
#include <assert.h>
#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include <cmath>
#include <iostream>
#include <fstream>
#include <stdlib.h>
#include <regex>

#include <sycl/sycl.hpp>
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
#    include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
#endif
#include <sycl/half_type.hpp>

#include "ggml-sycl.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"

#include "ggml-sycl/add-id.hpp"
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/norm.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml.h"

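// Global runtime toggles. Most are populated from environment variables
// (GGML_SYCL_DEBUG, GGML_SYCL_DISABLE_OPT, GGML_SYCL_DISABLE_GRAPH,
// GGML_SYCL_DISABLE_DNN, GGML_SYCL_PRIORITIZE_DMMV) in ggml_check_sycl();
// g_ggml_sycl_use_async_mem_op is derived from the graph setting and from
// per-device support for the async memory allocation extension.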
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;

static ggml_sycl_device_info ggml_sycl_init() {
    ggml_sycl_device_info info = {};

    info.device_count = dpct::dev_mgr::instance().device_count();
    if (info.device_count == 0) {
        GGML_LOG_ERROR("%s: failed to initialize: %s\n", GGML_SYCL_NAME, __func__);
        return info;
    }

    GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);

    int64_t total_vram = 0;
/* This is a bit misleading; reserved for later */
// #if defined(SYCL_USE_XMX)
//     GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
// #else
//     GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
// #endif
    for (int i = 0; i < info.device_count; ++i) {
        info.devices[i].vmm = 0;
        dpct::device_info prop;
        sycl::device device = dpct::dev_mgr::instance().get_device(i);

        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
            prop, device)));

        info.default_tensor_split[i] = total_vram;
        total_vram += prop.get_global_mem_size();

        info.devices[i].cc =
            100 * prop.get_major_version() + 10 * prop.get_minor_version();
        info.devices[i].nsm = prop.get_max_compute_units();
        info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
        info.devices[i].smpbo = prop.get_local_mem_size();

        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
    }

    for (int id = 0; id < info.device_count; ++id) {
        info.default_tensor_split[id] /= total_vram;
    }
    return info;
}

const ggml_sycl_device_info & ggml_sycl_info() {
    static ggml_sycl_device_info info = ggml_sycl_init();
    return info;
}

static void print_device_detail(int id, sycl::device &device, std::string device_type) {

    dpct::device_info prop;
    SYCL_CHECK(CHECK_TRY_ERROR(
        dpct::get_device_info(prop, device)));

    std::string version;
    version += std::to_string(prop.get_major_version());
    version += ".";
    version += std::to_string(prop.get_minor_version());

    device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
    std::string name = std::string(prop.get_name());
    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");

    auto global_mem_size = prop.get_global_mem_size()/1000000;
    GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
            name.c_str(), version.c_str(), prop.get_max_compute_units(),
            prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}

static void print_device_opt_feature(int device_count) {
    GGML_LOG_INFO("SYCL Optimization Feature:\n");
    GGML_LOG_INFO(
        "|ID|        Device Type|Reorder|\n");
    GGML_LOG_INFO(
        "|--|-------------------|-------|\n");
    std::map<std::string, size_t> DeviceNums;
    for (int id = 0; id < device_count; ++id) {
      sycl::device device = dpct::dev_mgr::instance().get_device(id);
      std::string backend_type = get_device_backend_and_type(device);
      int type_id = DeviceNums[backend_type]++;
      std::stringstream device_type;
      device_type << "[" << backend_type << ":" << std::to_string(type_id)
                  << "]";
      std::string device_type_s = device_type.str();
      device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
      GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
        ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
    }

}

void ggml_backend_sycl_print_sycl_devices() {
    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
    int device_count = dpct::dev_mgr::instance().device_count();
    std::map<std::string, size_t> DeviceNums;
    GGML_LOG_INFO("Found %d SYCL devices:\n", device_count);

    GGML_LOG_INFO(
        "|  |                   |                                       |      "
        " |Max    |        |Max  |Global |                     |\n");
    GGML_LOG_INFO(
        "|  |                   |                                       |      "
        " |compute|Max work|sub  |mem    |                     |\n");
    GGML_LOG_INFO(
        "|ID|        Device Type|                                   "
        "Name|Version|units  |group   |group|size   |       Driver version|\n");
    GGML_LOG_INFO(
        "|--|-------------------|---------------------------------------|------"
        "-|-------|--------|-----|-------|---------------------|\n");

    for (int id = 0; id < device_count; ++id) {
      sycl::device device = dpct::dev_mgr::instance().get_device(id);
      std::string backend_type = get_device_backend_and_type(device);
      int type_id = DeviceNums[backend_type]++;
      std::stringstream device_type;
      device_type << "[" << backend_type << ":" << std::to_string(type_id)
                  << "]";
      print_device_detail(id, device, device_type.str());
    }

    print_device_opt_feature(device_count);
}

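// Read a non-negative integer from the environment variable `env_name`;
// returns `default_val` when the variable is unset or cannot be parsed.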
static inline int get_sycl_env(const char *env_name, int default_val) {
    char *user_device_string = getenv(env_name);
    int user_number = default_val;

    unsigned n;
    if (user_device_string != NULL &&
        sscanf(user_device_string, " %u", &n) == 1) {
        user_number = (int)n;
    } else {
        user_number = default_val;
    }
    return user_number;
}

static void ggml_check_sycl() try {
    static bool initialized = false;

    if (!initialized) {
        g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
        g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
        g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
        g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
        GGML_LOG_INFO("Running with Environment Variables:\n");
        GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
        GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
#ifdef GGML_SYCL_GRAPH
        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
#else
        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
#endif
#if GGML_SYCL_DNNL
        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
#else
        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
        GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
        GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
#else
        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
#endif
#if defined(GGML_SYCL_F16)
        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
#else
        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
#endif

/* DO NOT REMOVE; kept for a future XMX optimization.
#if defined(SYCL_USE_XMX)
        fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
#else
        fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
#endif
*/
        // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
        // properly recorded. As this SYCL extension matures it may be beneficial to enable it as the default path and in
        // other places.
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
        g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
        if (g_ggml_sycl_use_async_mem_op) {
            for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
                if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
                    g_ggml_sycl_use_async_mem_op = 0;
                    break;
                }
            }
        }
#endif
        if (CHECK_TRY_ERROR(g_all_sycl_device_count =
                            dpct::dev_mgr::instance().device_count()) != 0) {
            initialized = true;
            g_sycl_loaded = false;
            return;
        }
        GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);

        initialized = true;
        g_sycl_loaded = true;
        ggml_backend_sycl_print_sycl_devices();
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

/*
device_index: device index from 0 to n (contiguous numbers).
    It is used for device selection in the SYCL backend's internal data structures.
*/
inline void check_allow_gpu_index(const int device_index) {
  if (device_index >= ggml_sycl_info().device_count) {
    char error_buf[256];
    snprintf(
        error_buf,
        sizeof(error_buf),
        "%s error: device_index:%d is out of range: [0-%d]",
        __func__,
        device_index,
        ggml_sycl_info().device_count - 1);
    GGML_LOG_ERROR("%s\n", error_buf);
    assert(false);
  }
}

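// Fill `id_list` with the indices of the available SYCL devices; any remaining
// slots (up to `max_len`) are left as -1.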
GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
    for (int i = 0; i < max_len; i++) id_list[i] = -1;

    for (int i = 0; i < ggml_sycl_info().device_count; i++) {
        if (i >= max_len) break;
        id_list[i] = i;
    }
    return;
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

// sycl buffer

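// Context owned by a plain (single-device) SYCL buffer: the device index, the
// base device allocation, the queue used for transfers, and the per-tensor
// extras that must be released together with the buffer.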
struct ggml_backend_sycl_buffer_context {
    int device;
    void * dev_ptr = nullptr;
    queue_ptr stream;
    std::string name;
    optimize_feature opt_feature;
    std::vector<ggml_tensor_extra_gpu *> tensor_extras;

    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
        device(device), dev_ptr(dev_ptr), stream(stream) {
            check_allow_gpu_index(device);
            name = (GGML_SYCL_NAME + std::to_string(device));
            opt_feature = ggml_sycl_info().devices[device].opt_feature;
        }

    ~ggml_backend_sycl_buffer_context() {
        if (dev_ptr != nullptr) {
            ggml_sycl_set_device(device);
            SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
        }

        // release the extras used by tensors
        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
            release_extra_gpu(extra);
        }

    }
};

static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);

static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
    return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;
}

static void
ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
    ggml_sycl_set_device(ctx->device);

    delete ctx;
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
    return ctx->dev_ptr;
}

static enum ggml_status
ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                     ggml_tensor *tensor) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;

    if (tensor->view_src != NULL) {
        assert(tensor->view_src->buffer->buft == buffer->buft);
        return GGML_STATUS_SUCCESS;
    }
    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
        !g_ggml_sycl_disable_optimize) {
        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
        tensor->extra                 = extra;
        ctx->tensor_extras.push_back(extra);  // used to release it when the ctx is destroyed.
    }

    if (ggml_is_quantized(tensor->type)) {
        // initialize padding to 0 to avoid possible NaN values
        size_t original_size = ggml_nbytes(tensor);
        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);

        if (padded_size > original_size && tensor->view_src == nullptr) {
            SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
                (char *)tensor->data + original_size, 0,
                padded_size - original_size).wait()));
        }
    }
    return GGML_STATUS_SUCCESS;
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                ggml_tensor *tensor,
                                                const void *data, size_t offset,
                                                size_t size) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
    ggml_sycl_set_device(ctx->device);
    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
    SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
#ifndef _WIN32
    // Note: copy the data from mmap() into a host buffer first, then copy to the device.
    // This is a workaround for an mmap() issue on PVC GPUs.
    // This function is called while loading the model from disk; replacing the per-call
    // allocation with a persistent buffer would not save time and would add a potential memory-leak risk.
    char * host_buf = (char *) malloc(size);
    memcpy(host_buf, data, size);
    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
    free(host_buf);
#else
    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
#endif
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                const ggml_tensor *tensor,
                                                void *data, size_t offset,
                                                size_t size) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;

    ggml_sycl_set_device(ctx->device);
    auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();

    SYCL_CHECK(CHECK_TRY_ERROR(
        stream.memcpy(data, (const char *)tensor->data + offset, size)
            .wait()));
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

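// Copy between two devices by staging through a temporary host buffer. Used
// instead of a direct device-to-device memcpy, which is currently unreliable
// across GPUs (see the TODO in ggml_backend_sycl_buffer_cpy_tensor below).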
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                    const void *ptr_src, size_t size) {
    char *host_buf = (char *)malloc(size);
    q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
    q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
    free(host_buf);
}

static bool
ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                    const ggml_tensor *src,
                                    ggml_tensor *dst) try {
    bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
    if (is_cpy_supported) {
        ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
        ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;

        ggml_sycl_set_device(src_ctx->device);
        /*
        DPCT1009:198: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        SYCL_CHECK(CHECK_TRY_ERROR(
            dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
        ggml_sycl_set_device(dst_ctx->device);
        /*
        DPCT1009:199: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        SYCL_CHECK(CHECK_TRY_ERROR(
            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
        /*
        DPCT1009:200: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */

        queue_ptr stream_dst = dst_ctx->stream;
        queue_ptr stream_src = src_ctx->stream;
        size_t size = ggml_nbytes(src);

        // TODO: dirty workaround for a known issue: device-to-device copy across GPUs fails.
        dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);

// TODO: known issue: device-to-device copy across GPUs fails. Re-enable this path when the issue is fixed. DO NOT remove.
#if 0
        SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
            (char *)dst->data, (const char *)src->data, size).wait()));

        /*
        DPCT1009:201: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        SYCL_CHECK(CHECK_TRY_ERROR(
            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
#endif
        return true;
    }
    return false;
    GGML_UNUSED(buffer);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}

static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
                                           uint8_t value) try {
    GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;

    ggml_sycl_set_device(ctx->device);
    queue_ptr stream = ctx->stream;
    SYCL_CHECK(
        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));

    SYCL_CHECK(CHECK_TRY_ERROR((*stream)
                                    .memset(ctx->dev_ptr, value, buffer->size)
                                    .wait()));
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
                                                   size_t offset, size_t size) {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
    GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
    SYCL_CHECK(ggml_sycl_set_device(ctx->device));
    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
    if (size == 0) {
        return;  // Nothing to do
    }
    if (tensor->data == nullptr) {
        GGML_ABORT("Error: Tensor data pointer is null.\n");
    }
    void * target_ptr = static_cast<char *>(tensor->data) + offset;
    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
    SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
}

static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
    if (buffer == nullptr) {
        return;
    }

    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;

    if (ctx != nullptr) {
        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
            release_extra_gpu(extra);
        }
        ctx->tensor_extras.clear();  // reset the tensor_extras vector
    }
}

static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
    /* .memset_tensor   = */ ggml_backend_sycl_buffer_memset_tensor,
    /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_sycl_buffer_clear,
    /* .reset           = */ ggml_backend_sycl_buffer_reset,
};

// sycl buffer type
struct ggml_backend_sycl_buffer_type_context {
    int device;
    std::string name;

    // each buffer type has its own stream
    queue_ptr stream = nullptr;
};

static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;

    return ctx->name.c_str();
}

static ggml_backend_buffer_t
ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                           size_t size) try {
    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
    ggml_sycl_set_device(buft_ctx->device);
    const queue_ptr stream = buft_ctx->stream;
    size = std::max(size, (size_t)1); // syclMalloc returns null for size 0

    void * dev_ptr;
    SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
                                    size, *stream)));
    if (!dev_ptr) {
      GGML_LOG_ERROR("%s: can't allocate %zu Bytes of memory on device\n", __func__, size);
      return nullptr;
    }
    ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
    return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return 128;
    GGML_UNUSED(buft);
}

static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    return dpct::get_current_device().get_max_mem_alloc_size();

    GGML_UNUSED(buft);
}

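// For quantized tensors the reported allocation size is padded so that the
// last row is a multiple of MATRIX_ROW_PADDING elements; kernels may read past
// the nominal end of the data, and init_tensor zero-fills the padding.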
static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    size_t size = ggml_nbytes(tensor);
    int64_t ne0 = tensor->ne[0];

    if (ggml_is_quantized(tensor->type)) {
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }
    }

    return size;

    GGML_UNUSED(buft);
}

static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_sycl_buffer_type_get_name,
    /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,
    /* .is_host          = */ NULL,
};

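// Public accessor for the per-device buffer type. Illustrative use through the
// generic ggml-backend API (a sketch, not code from this file; `nbytes` is a
// caller-chosen size):
//
//   ggml_backend_buffer_type_t buft = ggml_backend_sycl_buffer_type(0);
//   ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, nbytes);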
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    auto dev_count = ggml_backend_sycl_get_device_count();

    if (device >= dev_count || device < 0) {
        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
            device, dev_count - 1);
        GGML_ASSERT(device < dev_count);
    }
    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];

    static bool ggml_backend_sycl_buffer_type_initialized = false;

    if (!ggml_backend_sycl_buffer_type_initialized) {
        for (int i = 0; i < dev_count; i++) {
            auto & device_i = dpct::dev_mgr::instance().get_device(i);
            queue_ptr stream = &(device_i.default_queue());
            ggml_backend_sycl_buffer_types[i] = {
                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
            };
        }
        ggml_backend_sycl_buffer_type_initialized = true;
    }
    return &ggml_backend_sycl_buffer_types[device];
}

static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");

    int device = ctx->device;
    if (device >= ggml_sycl_info().device_count || device < 0) {
        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
            device, ggml_sycl_info().device_count - 1);
        GGML_ASSERT(device < ggml_sycl_info().device_count);
    }
    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];

    static bool ggml_backend_sycl_buffer_type_initialized = false;

    if (!ggml_backend_sycl_buffer_type_initialized) {
        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
            ggml_backend_sycl_buffer_types[i] = {
                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
                /* .device   = */ nullptr,
                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
            };
        }
        ggml_backend_sycl_buffer_type_initialized = true;
    }
    return &ggml_backend_sycl_buffer_types[device];
}

// sycl split buffer

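// Split buffers distribute the rows of a matrix across devices according to
// the cumulative fractions in tensor_split. get_row_rounding() returns the
// granularity (in rows) that each per-device slice is rounded to, chosen from
// the highest compute capability among the participating devices.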
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
    int64_t min_compute_capability = INT_MAX;
    int64_t max_compute_capability = INT_MIN;
    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
            if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
                min_compute_capability = ggml_sycl_info().devices[i].cc;
            }
            if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
                max_compute_capability = ggml_sycl_info().devices[i].cc;
            }
        }
    }

    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return max_compute_capability >= VER_GEN9 ? 128 : 64;
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
            return 64;
        case GGML_TYPE_F16:
        case GGML_TYPE_F32:
            return 1;
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return max_compute_capability >= VER_GEN9 ? 128 : 64;
        case GGML_TYPE_IQ3_S:
            return max_compute_capability >= VER_GEN9 ? 128 : 64;
        case GGML_TYPE_Q6_K:
            return 64;
        default:
            GGML_ABORT("fatal error");
    }
}

static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
    const int64_t nrows = ggml_nrows(tensor);
    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);

    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
    *row_low -= *row_low % rounding;
    if (id == ggml_sycl_info().device_count - 1) {
        *row_high = nrows;
    } else {
        *row_high = nrows*tensor_split[id + 1];
        *row_high -= *row_high % rounding;
    }
}

static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
}

struct ggml_backend_sycl_split_buffer_type_context {
    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
};

struct ggml_backend_sycl_split_buffer_context {
    ~ggml_backend_sycl_split_buffer_context() try {
        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
            release_extra_gpu(extra, streams);
        }
    }
    catch (sycl::exception const &exc) {
      std::cerr << exc.what() << "Exception caught at file:" << __FILE__
                << ", line:" << __LINE__ << std::endl;
      std::exit(1);
    }

    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
    std::vector<queue_ptr> streams;
};

static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
    delete ctx;
}

static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
    return (void *)0x1000;

    GGML_UNUSED(buffer);
}

static enum ggml_status
ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                           ggml_tensor *tensor) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported

    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;

    const int64_t ne0 = tensor->ne[0];

    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};

    ctx->tensor_extras.push_back(extra);
    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));

    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        int64_t row_low, row_high;
        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);

        int64_t nrows_split = row_high - row_low;
        if (nrows_split == 0) {
            continue;
        }

        size_t size = ggml_nbytes_split(tensor, nrows_split);
        const size_t original_size = size;

        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }

        // FIXME: do not crash if SYCL Buffer alloc fails
        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
        ggml_sycl_set_device(i);
        const queue_ptr stream = ctx->streams[i];
        char * buf;
        /*
        DPCT1009:208: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
                                        size, *stream)));
        if (!buf) {
            char err_buf[1024];
            snprintf(err_buf, 1023, "%s: can't allocate %zu Bytes of memory on device\n", __func__, size);
            throw std::runtime_error(err_buf);
        }
        // set padding to 0 to avoid possible NaN values
        if (size > original_size) {
            /*
            DPCT1009:209: SYCL uses exceptions to report errors and does not use
            the error codes. The original code was commented out and a warning
            string was inserted. You need to rewrite this code.
            */
            SYCL_CHECK(CHECK_TRY_ERROR(
                (*stream)
                    .memset(buf + original_size, 0, size - original_size)
                    .wait()));
        }

        extra->data_device[i] = buf;

        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
            /*
            DPCT1009:210: SYCL uses exceptions to report errors and does not use
            the error codes. The original code was commented out and a warning
            string was inserted. You need to rewrite this code.
            */
            SYCL_CHECK(
                CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
        }
    }
    tensor->extra = extra;
    return GGML_STATUS_SUCCESS;
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void
ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                          ggml_tensor *tensor, const void *data,
                                          size_t offset, size_t size) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    // split tensors must always be set in their entirety at once
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;

    const int64_t ne0 = tensor->ne[0];
    const size_t nb1 = tensor->nb[1];
    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;

    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        int64_t row_low, row_high;
        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);

        int64_t nrows_split = row_high - row_low;
        if (nrows_split == 0) {
            continue;
        }

        const size_t offset_split = row_low*nb1;
        size_t size = ggml_nbytes_split(tensor, nrows_split);
        const size_t original_size = size;

        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }

        const char * buf_host = (const char *)data + offset_split;
        /*
        DPCT1009:211: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        ggml_sycl_set_device(i);
        const queue_ptr stream = ctx->streams[i];
        SYCL_CHECK(CHECK_TRY_ERROR(
            (*stream)
                .memcpy(extra->data_device[i], buf_host, original_size)
                .wait()));
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void
ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                          const ggml_tensor *tensor, void *data,
                                          size_t offset, size_t size) try {
    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    // split tensors must always be read in their entirety at once
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;

    const int64_t ne0 = tensor->ne[0];
    const size_t nb1 = tensor->nb[1];
    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;

    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        int64_t row_low, row_high;
        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);

        int64_t nrows_split = row_high - row_low;
        if (nrows_split == 0) {
            continue;
        }

        const size_t offset_split = row_low*nb1;
        size_t size = ggml_nbytes_split(tensor, nrows_split);
        const size_t original_size = size;

        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }

        char * buf_host = (char *)data + offset_split;
        /*
        DPCT1009:212: SYCL uses exceptions to report errors and does not use the
        error codes. The original code was commented out and a warning string
        was inserted. You need to rewrite this code.
        */
        ggml_sycl_set_device(i);
        const queue_ptr stream = ctx->streams[i];
        SYCL_CHECK(CHECK_TRY_ERROR(
            (*stream)
                .memcpy(buf_host, extra->data_device[i], original_size)
                .wait()));
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}

static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    GGML_UNUSED(buffer);
    GGML_UNUSED(value);
}

static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_sycl_split_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_sycl_split_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_sycl_split_buffer_init_tensor,
    /* .memset_tensor   = */ NULL,
    /* .set_tensor      = */ ggml_backend_sycl_split_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_sycl_split_buffer_get_tensor,
    /* .cpy_tensor      = */ NULL,
    /* .clear           = */ ggml_backend_sycl_split_buffer_clear,
    /* .reset           = */ NULL,
};

// sycl split buffer type

static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return GGML_SYCL_NAME "_Split";

    GGML_UNUSED(buft);
}

static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
    return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;
}

static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
    // instead, we allocate them for each tensor separately in init_tensor
    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
    ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();

    return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
}

static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return 128;
    GGML_UNUSED(buft);
}

static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;

    size_t total_size = 0;

    const int64_t ne0 = tensor->ne[0];

    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        int64_t row_low, row_high;
        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);

        int64_t nrows_split = row_high - row_low;
        if (nrows_split == 0) {
            continue;
        }

        total_size += ggml_nbytes_split(tensor, nrows_split);

        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }
    }

    return total_size;
}

static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_sycl_split_buffer_type_get_name,
    /* .alloc_buffer     = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_sycl_split_buffer_type_get_alignment,
    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
    /* .get_alloc_size   = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
    /* .is_host          = */ ggml_backend_sycl_split_buffer_type_is_host,
};

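// Create (or reuse) a split buffer type for the given per-device proportions.
// A null or all-zero tensor_split falls back to the default split computed
// from each device's total memory. Illustrative use (a sketch; the proportions
// are normalized internally, so {1, 1} and {0.5, 0.5} are equivalent):
//
//   float split[GGML_SYCL_MAX_DEVICES] = { 1.0f, 1.0f };
//   ggml_backend_buffer_type_t buft = ggml_backend_sycl_split_buffer_type(split);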
ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
    ggml_check_sycl();
    // FIXME: this is not thread safe
    static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;

    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};

    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
    if (all_zero) {
        tensor_split_arr = ggml_sycl_info().default_tensor_split;
    } else {
        float split_sum = 0.0f;
        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
            tensor_split_arr[i] = split_sum;
            split_sum += tensor_split[i];
        }
        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
            tensor_split_arr[i] /= split_sum;
        }
    }

    auto it = buft_map.find(tensor_split_arr);
    if (it != buft_map.end()) {
        return &it->second;
    }

    struct ggml_backend_buffer_type buft {
        /* .iface   = */ ggml_backend_sycl_split_buffer_type_interface,
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
        /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
    };

    auto result = buft_map.emplace(tensor_split_arr, buft);
    return &result.first->second;
}

// host buffer type

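// The host buffer type hands out plain aligned host allocations
// (TENSOR_ALIGNMENT) and reuses the CPU buffer interface for everything except
// freeing; allocation falls back to the regular CPU buffer type on failure.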
static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
    return GGML_SYCL_NAME "_Host";

    GGML_UNUSED(buft);
}

inline void * aligned_malloc_host(size_t alignment, size_t size) {
#ifdef _WIN32
    return _aligned_malloc(size, alignment);
#else
    return aligned_alloc(alignment, size);
#endif
}

inline void free_aligned_mem_host(void * memblock) {
#ifdef _WIN32
    _aligned_free(memblock);
#else
    free(memblock);
#endif
}

static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    free_aligned_mem_host((void *)buffer->context);
}

static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
    if (ptr == nullptr) {
        // fallback to cpu buffer
        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    }

    // FIXME: this is a hack to avoid having to implement a new buffer type
    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
    buffer->buft = buft;
    buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;

    return buffer;
}

ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_sycl_host_buffer_type_name,
            /* .alloc_buffer     = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
            /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength
            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
        /* .context  = */ nullptr,
    };

    return &ggml_backend_sycl_buffer_type_host;
}

// buffer pool for sycl (legacy)
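// Legacy device memory pool: keeps up to MAX_SYCL_BUFFERS freed allocations
// and serves requests with a best-fit search, over-allocating new blocks by 5%
// as look-ahead to improve reuse.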
1214struct ggml_sycl_pool_leg : public ggml_sycl_pool {
1215    static const int MAX_SYCL_BUFFERS = 256;
1216
1217    int device;
1218    queue_ptr qptr;
1219    struct ggml_sycl_buffer {
1220        void * ptr = nullptr;
1221        size_t size = 0;
1222    };
1223
1224    ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
1225    size_t pool_size = 0;
1226
1227    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}
1228
1229    ~ggml_sycl_pool_leg() {
1230        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1231            ggml_sycl_buffer & b = buffer_pool[i];
1232            if (b.ptr != nullptr) {
1233                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1234                pool_size -= b.size;
1235            }
1236        }
1237        GGML_ASSERT(pool_size == 0);
1238    }
1239
1240    void * alloc(size_t size, size_t * actual_size) override {
1241#ifdef DEBUG_sycl_MALLOC
1242        int nnz = 0;
1243        size_t max_size = 0;
1244#endif
1245        size_t best_diff = 1ull << 36;
1246        int ibest = -1;
1247        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1248            ggml_sycl_buffer& b = buffer_pool[i];
1249            if (b.ptr != nullptr) {
1250#ifdef DEBUG_sycl_MALLOC
1251                ++nnz;
1252                if (b.size > max_size) max_size = b.size;
1253#endif
1254                if (b.size >= size) {
1255                    size_t diff = b.size - size;
1256                    if (diff < best_diff) {
1257                        best_diff = diff;
1258                        ibest = i;
1259                        if (!best_diff) {
1260                            void * ptr = b.ptr;
1261                            *actual_size = b.size;
1262                            b.ptr = nullptr;
1263                            b.size = 0;
1264                            return ptr;
1265                        }
1266                    }
1267                }
1268            }
1269        }
1270        if (ibest >= 0) {
1271            ggml_sycl_buffer& b = buffer_pool[ibest];
1272            void * ptr = b.ptr;
1273            *actual_size = b.size;
1274            b.ptr = nullptr;
1275            b.size = 0;
1276            return ptr;
1277        }
1278        void * ptr;
1279        size_t look_ahead_size = (size_t) (1.05 * size);
1280
1281        SYCL_CHECK(
1282            CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
1283                                look_ahead_size, *qptr)));
1284        if (!ptr) {
1285            GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size);
1286            return nullptr;
1287        }
1288
1289        *actual_size = look_ahead_size;
1290        pool_size += look_ahead_size;
1291
1292#ifdef DEBUG_SYCL_MALLOC
1293        GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
1294                (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
1295#endif
1296
1297        // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
1298        return ptr;
1299    }
1300
1301    void free(void * ptr, size_t size) override {
1302        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1303            ggml_sycl_buffer& b = buffer_pool[i];
1304            if (b.ptr == nullptr) {
1305                b.ptr = ptr;
1306                b.size = size;
1307                return;
1308            }
1309        }
1310        GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_SYCL_BUFFERS\n");
1311        SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
1312        pool_size -= size;
1313    }
1314};
1315
1316struct ggml_sycl_pool_host : public ggml_sycl_pool {
1317    queue_ptr qptr;
1318    int       device;
1319
1320    inline static int counter{ 0 };
1321
1322    struct ggml_sycl_buffer {
1323        void * ptr  = nullptr;
1324        size_t size = 0;
1325    };
1326
1327    // Set arbitrarily to 64
1328    static constexpr int          MAX_POOL_SIZE{ 64 };
1329    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1330    size_t                        pool_size   = 0;
1331
1332    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
1333
1334    ~ggml_sycl_pool_host() {
1335        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1336            ggml_sycl_buffer & b = buffer_pool[i];
1337            if (b.ptr != nullptr) {
1338                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1339                b.ptr = nullptr;
1340                pool_size -= b.size;
1341                b.size = 0;
1342            }
1343        }
1344        counter = 0;
1345    }
1346
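    // Slots are handed out in order through the static `counter` and wrap back to
    // slot 0 once all MAX_POOL_SIZE slots are in use. A reused slot returns its
    // existing pointer and only records the new size, so reuse assumes the request
    // fits the buffer originally allocated for that slot.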
1347    void * alloc(size_t size, size_t * actual_size) override {
1348        if (counter == MAX_POOL_SIZE) {
1349            ggml_sycl_buffer b               = buffer_pool[0];
1350            void *           ptr             = b.ptr;
1351            *actual_size                     = b.size;
1352            counter                          = 1;
1353            return ptr;
1354        }
1355        ggml_sycl_buffer & b = buffer_pool[counter];
1356
1357        if (b.ptr == nullptr) {
1358            void * ptr;
1359
1360            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1361            if (!ptr) {
1362                GGML_LOG_ERROR("%s: can't allocate %zu bytes of memory on host\n", __func__, size);
1363                return nullptr;
1364            }
1365            pool_size += size;
1366            *actual_size = size;
1367            counter      = counter + 1;
1368            return ptr;
1369        } else {
1370            ++counter;
1371            b.size = size;
1372            return b.ptr;
1373        }
1374    }
1375
1376    void free(void * ptr, size_t size) override {
1377        // If the pool is not full, store the pointer in the first empty slot found.
1378        // Otherwise do nothing; the pointers will be freed when the pool is destroyed.
1379        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1380            ggml_sycl_buffer & b = buffer_pool[i];
1381            if (b.ptr == nullptr) {
1382                b.ptr  = ptr;
1383                b.size = size;
1384                return;
1385            }
1386        }
1387    }
1388};
1389
1390std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
1391    // return a host-memory pool to speed up memory management
1392    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
1393}
1394
1395std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
1396    // TBD: NO VMM support
1397    // if (ggml_sycl_info().devices[device].vmm) {
1398    //     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
1399    // }
1400   return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
1401}
1402
1403// TBD pool with virtual memory management
1404// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
1405
1406/// kernels
1407typedef void (*ggml_sycl_op_mul_mat_t)(
1408    ggml_backend_sycl_context & ctx,
1409    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
1410    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
1411    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
1412    const int64_t src1_ncols, const int64_t src1_padded_row_size,
1413    const queue_ptr &stream);
1414
1415
1416
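// f16 x f32 matrix-vector kernel for src0 stored transposed and permuted (see the
// indexing of `ix` below): each sub-group accumulates one dst element and reduces
// its partial sums with a shuffle-XOR sub-group reduction.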
1417static void mul_mat_p021_f16_f32(
1418    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
1419    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
1420    const sycl::nd_item<3> &item_ct1) {
1421
1422    const sycl::half *x = (const sycl::half *)vx;
1423
1424    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1425                      item_ct1.get_local_id(1);
1426    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1427                        item_ct1.get_local_id(0);
1428    const int channel_x = channel / (nchannels_y / nchannels_x);
1429
1430    const int nrows_y = ncols_x;
1431    const int nrows_dst = nrows_x;
1432    const int row_dst = row_x;
1433
1434    float tmp = 0.0f;
1435
1436    for (int col_x0 = 0; col_x0 < ncols_x;
1437         col_x0 += item_ct1.get_local_range(2)) {
1438        const int col_x = col_x0 + item_ct1.get_local_id(2);
1439
1440        if (col_x >= ncols_x) {
1441            break;
1442        }
1443
1444        // x is transposed and permuted
1445        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
1446        const float xi =
1447            sycl::vec<sycl::half, 1>(x[ix])
1448                .convert<float, sycl::rounding_mode::automatic>()[0];
1449
1450        const int row_y = col_x;
1451
1452
1453        // y is not transposed but permuted
1454        const int iy = channel*nrows_y + row_y;
1455
1456        tmp += xi * y[iy];
1457    }
1458
1459    // dst is not transposed and not permuted
1460    const int idst = channel*nrows_dst + row_dst;
1461
1462    // sum up partial sums and write back result
1463#pragma unroll
1464    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
1465        tmp +=
1466            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
1467    }
1468
1469    if (item_ct1.get_local_id(2) == 0) {
1470        dst[idst] = tmp;
1471    }
1472}
1473
1474static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1475    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
1476    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
1477    const sycl::nd_item<3> &item_ct1) {
1478
1479    const sycl::half *x = (const sycl::half *)vx;
1480
1481    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1482                      item_ct1.get_local_id(1);
1483    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1484                        item_ct1.get_local_id(0);
1485    const int channel_x = channel / channel_x_divisor;
1486
1487    const int nrows_dst = nrows_x;
1488    const int row_dst   = row_x;
1489
1490    const int idst = channel*nrows_dst + row_dst;
1491
1492    float tmp = 0.0f;
1493
1494    for (int col_x0 = 0; col_x0 < ncols_x;
1495         col_x0 += item_ct1.get_local_range(2)) {
1496        const int col_x = col_x0 + item_ct1.get_local_id(2);
1497
1498        if (col_x >= ncols_x) {
1499            break;
1500        }
1501
1502        const int row_y = col_x;
1503
1504        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
1505        const int iy = channel * channel_stride_y + row_y;
1506
1507        const float xi =
1508            sycl::vec<sycl::half, 1>(x[ix])
1509                .convert<float, sycl::rounding_mode::automatic>()[0];
1510
1511        tmp += xi * y[iy];
1512    }
1513
1514    // sum up partial sums and write back result
1515#pragma unroll
1516    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
1517        tmp +=
1518            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
1519    }
1520
1521    if (item_ct1.get_local_id(2) == 0) {
1522        dst[idst] = tmp;
1523    }
1524}
1525
1526static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
1527                           const sycl::nd_item<3> &item_ct1) {
1528    const int row = item_ct1.get_group(1);
1529    const int col = item_ct1.get_local_id(2);
1530
1531    float sum = 0.0f;
1532    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
1533        sum += x[row * ncols + i];
1534    }
1535
1536    sum = warp_reduce_sum(sum, item_ct1);
1537
1538    if (col == 0) {
1539        dst[row] = sum;
1540    }
1541}
1542
1543
1544template<typename T>
1545static inline void ggml_sycl_swap(T & a, T & b) {
1546    T tmp = a;
1547    a = b;
1548    b = tmp;
1549}
1550
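// Bitonic argsort of one row per work-group: ncols is padded up to a power of two
// (ncols_pad), each work-item handles `tasks_per_thread` consecutive columns, the
// index array lives in local memory, and padding indices (>= ncols) sort last.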
1551template <ggml_sort_order order>
1552__dpct_inline__ static void
1553k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1554                  const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
1555                  uint8_t *dpct_local) {
1556    // bitonic sort
1557    int col_index =  item_ct1.get_local_id(2);
1558    int row = item_ct1.get_group(1);
1559
1560    for (int i = 0; i < tasks_per_thread; i++) {
1561        int col = col_index * tasks_per_thread + i;
1562        if (col >= ncols_pad) {
1563            return;
1564        }
1565    }
1566
1567    const float * x_row = x + row * ncols;
1568    auto dst_row = (int *)dpct_local;
1569
1570    // initialize indices
1571    for (int i=0;i<tasks_per_thread;i++){
1572        int col = col_index*tasks_per_thread+i;
1573        dst_row[col] = col;
1574    }
1575
1576    item_ct1.barrier(sycl::access::fence_space::local_space);
1577
1578    for (int k = 2; k <= ncols_pad; k *= 2) {
1579        for (int j = k / 2; j > 0; j /= 2) {
1580            for (int i = 0; i < tasks_per_thread; i++) {
1581                int col = col_index * tasks_per_thread + i;
1582                int ixj = col ^ j;
1583                if (ixj > col) {
1584                    if ((col & k) == 0) {
1585                        if (dst_row[col] >= ncols ||
1586                            (dst_row[ixj] < ncols &&
1587                             (order == GGML_SORT_ORDER_ASC
1588                                  ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
1589                                  : x_row[dst_row[col]] <
1590                                        x_row[dst_row[ixj]]))) {
1591                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1592                        }
1593                    } else {
1594                        if (dst_row[ixj] >= ncols ||
1595                            (dst_row[col] < ncols &&
1596                             (order == GGML_SORT_ORDER_ASC
1597                                  ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
1598                                  : x_row[dst_row[col]] >
1599                                        x_row[dst_row[ixj]]))) {
1600                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1601                        }
1602                    }
1603                }
1604                item_ct1.barrier(sycl::access::fence_space::local_space);
1605            }
1606        }
1607    }
1608
1609    // copy the result to dst without the padding
1610    for (int i = 0; i < tasks_per_thread; i++) {
1611        int col = col_index * tasks_per_thread + i;
1612        if (col < ncols) {
1613            dst[row * ncols + col] = dst_row[col];
1614        }
1615    }
1616}
1617
1618static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
1619                              const sycl::nd_item<3> &item_ct1) {
1620    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1621                    item_ct1.get_local_id(1);
1622    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1623                    item_ct1.get_local_id(2);
1624
1625    if (col >= ncols) {
1626        return;
1627    }
1628
1629    const int i = row*ncols + col;
1630    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
1631    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
1632    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
1633}
1634
1635static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
1636                      const sycl::nd_item<3> &item_ct1) {
1637    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1638                  item_ct1.get_local_id(2);
1639
1640    if (i >= k) {
1641        return;
1642    }
1643
1644    dst[i] = scale * x[i] + bias;
1645}
1646
1647
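// Generic NCHW 2D pooling: each work-item produces one output element, decodes its
// (channel, oh, ow) position from `idx`, and scans the clamped kh x kw input window,
// averaging for GGML_OP_POOL_AVG or taking the maximum for GGML_OP_POOL_MAX.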
1648template <typename Ti, typename To>
1649static  void pool2d_nchw_kernel(
1650        const int ih, const int iw, const int oh, const int ow,
1651        const int kh, const int kw, const int sh, const int sw,
1652        const int ph, const int pw, const int parallel_elements,
1653        const Ti* src, To* dst, const enum ggml_op_pool op,
1654        const sycl::nd_item<3> &item_ct1) {
1655        int idx = item_ct1.get_local_id(2) +
1656                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
1657        if (idx >= parallel_elements) {
1658            return;
1659        }
1660
1661        const int I_HW = ih * iw;
1662        const int O_HW = oh * ow;
1663        const int nc = idx / O_HW;
1664        const int cur_oh = idx % O_HW / ow;
1665        const int cur_ow = idx % O_HW % ow;
1666        const Ti* i_ptr = src + nc * I_HW;
1667        To* o_ptr = dst + nc * O_HW;
1668        const int start_h = cur_oh * sh - ph;
1669        const int bh = sycl::max(0, start_h);
1670        const int eh = sycl::min(ih, start_h + kh);
1671        const int start_w = cur_ow * sw - pw;
1672        const int bw = sycl::max(0, start_w);
1673        const int ew = sycl::min(iw, start_w + kw);
1674
1675        To res = 0;
1676
1677        switch (op) {
1678            case GGML_OP_POOL_AVG: res = 0; break;
1679            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
1680            default:
1681                res      = (To) sycl::nan(uint32_t(0));
1682                break;
1683        }
1684
1685        for (int i = bh; i < eh; i += 1) {
1686            for (int j = bw; j < ew; j += 1) {
1687#if DPCT_COMPATIBILITY_TEMP >= 350
1688                /*
1689                DPCT1098:106: The '*' expression is used instead of the __ldg
1690                call. These two expressions do not provide the exact same
1691                functionality. Check the generated code for potential precision
1692                and/or performance issues.
1693                */
1694                Ti cur = *(i_ptr + i * iw + j);
1695#else
1696                Ti cur = i_ptr[i * iw + j];
1697#endif
1698                switch (op) {
1699                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
1700                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
1701                    default:
1702                        res = (To) sycl::nan(uint32_t(0));
1703                        break;
1704                }
1705            }
1706        }
1707        o_ptr[cur_oh * ow + cur_ow] = res;
1708}
1709
1710
1711static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
1712                                           float *dst, const int ncols_x,
1713                                           const int nrows_x,
1714                                           const int nchannels_x,
1715                                           const int nchannels_y,
1716                                           queue_ptr stream) {
1717
1718    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
1719    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1720    {
1721        dpct::has_capability_or_fail(stream->get_device(),
1722                                     {sycl::aspect::fp16});
1723
1724        stream->parallel_for(
1725            sycl::nd_range<3>(block_nums * block_dims, block_dims),
1726            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1727                mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
1728                                     nchannels_y, item_ct1);
1729            });
1730    }
1731}
1732
1733static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1734    const void *vx, const float *y, float *dst, const int ncols_x,
1735    const int nrows_x, const int row_stride_x, const int nchannels_x,
1736    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
1737
1738    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
1739    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1740    {
1741        dpct::has_capability_or_fail(stream->get_device(),
1742                                     {sycl::aspect::fp16});
1743
1744        stream->parallel_for(
1745            sycl::nd_range<3>(block_nums * block_dims, block_dims),
1746            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1747                mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
1748                                       row_stride_x, channel_stride_x, channel_stride_y,
1749                                       nchannels_y / nchannels_x, item_ct1);
1750            });
1751    }
1752}
1753
1754
1755
1756static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
1757                           const int k, queue_ptr stream) {
1758    const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
1759    stream->parallel_for(
1760        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
1761                              sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
1762                          sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
1763        [=](sycl::nd_item<3> item_ct1) {
1764            scale_f32(x, dst, scale, bias, k, item_ct1);
1765        });
1766}
1767
1768
1769static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
1770                              const int nrows, queue_ptr stream) {
1771    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1772    const sycl::range<3> block_nums(1, nrows, 1);
1773    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
1774                         [=](sycl::nd_item<3> item_ct1)
1775                             [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1776                                 k_sum_rows_f32(x, dst, ncols, item_ct1);
1777                             });
1778}
1779
1780static int next_power_of_2(int x) {
1781    int n = 1;
1782    while (n < x) {
1783        n *= 2;
1784    }
1785    return n;
1786}
1787
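// Launcher for the bitonic argsort: pads ncols to the next power of two, caps the
// work-group size at the device limit, and requires the padded index row to fit in
// local memory (asserted against the device's local memory size below).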
1788static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1789                                 const int nrows, ggml_sort_order order,
1790                                 queue_ptr stream, int device) {
1791    // bitonic sort requires ncols to be power of 2
1792    const int ncols_pad = next_power_of_2(ncols);
1793
1794    int nth = 1;
1795    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
1796    while (nth < ncols_pad && nth < max_block_size)
1797        nth *= 2;
1798    if (nth > max_block_size)
1799        nth = max_block_size;
1800
1801    const int tasks_per_thread = ncols_pad / nth;
1802
1803    const sycl::range<3> block_dims(1, 1, nth);
1804    const sycl::range<3> block_nums(1, nrows, 1);
1805    const size_t shared_mem = ncols_pad * sizeof(int);
1806    GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
1807
1808    if (order == GGML_SORT_ORDER_ASC) {
1809        stream->submit([&](sycl::handler &cgh) {
1810            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1811                sycl::range<1>(shared_mem), cgh);
1812
1813            cgh.parallel_for(
1814                sycl::nd_range<3>(block_nums * block_dims, block_dims),
1815                [=](sycl::nd_item<3> item_ct1) {
1816                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
1817                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1818                        dpct_local_acc_ct1
1819                            .get_multi_ptr<sycl::access::decorated::no>()
1820                            .get());
1821                });
1822        });
1823    } else if (order == GGML_SORT_ORDER_DESC) {
1824        stream->submit([&](sycl::handler &cgh) {
1825            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1826                sycl::range<1>(shared_mem), cgh);
1827
1828            cgh.parallel_for(
1829                sycl::nd_range<3>(block_nums * block_dims, block_dims),
1830                [=](sycl::nd_item<3> item_ct1) {
1831                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
1832                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1833                        dpct_local_acc_ct1
1834                            .get_multi_ptr<sycl::access::decorated::no>()
1835                            .get());
1836                });
1837        });
1838    } else {
1839        GGML_ABORT("fatal error");
1840    }
1841}
1842
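// Per-row top-k indices in two phases: each work-item keeps an insertion-sorted
// local top-k of its strided columns, all candidates are staged in local memory,
// and work-item 0 merges them into the row's final result. k is limited to 32 by
// the fixed-size local arrays.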
1843static void top_k_f32_sycl(
1844    const float * src,
1845    int32_t * dst_indices,
1846    const int64_t ncols,
1847    const int64_t nrows,
1848    const int k,
1849    dpct::queue_ptr main_stream
1850) {
1851    const int block_size = 128;
1852
1853    const sycl::range<1> block_dims(block_size);
1854    const sycl::range<1> grid_dims(nrows);
1855
1856    main_stream->submit([&](sycl::handler &cgh) {
1857        sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
1858        sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
1859
1860        cgh.parallel_for(
1861            sycl::nd_range<1>(grid_dims * block_dims, block_dims),
1862            [=](sycl::nd_item<1> item_ct1) {
1863                const int row = item_ct1.get_group(0);
1864                const int tid = item_ct1.get_local_id(0);
1865
1866                if (row >= nrows) return;
1867
1868                const float * src_row = src + row * ncols;
1869                int32_t * dst_idx_row = dst_indices + row * k;
1870
1871                float local_vals[32];
1872                int local_idx[32];
1873
1874                for (int i = 0; i < k; i++) {
1875                    local_vals[i] = -FLT_MAX;
1876                    local_idx[i] = -1;
1877                }
1878
1879                for (int col = tid; col < ncols; col += block_size) {
1880                    float val = src_row[col];
1881
1882                    if (val > local_vals[k-1]) {
1883                        int pos = k - 1;
1884                        while (pos > 0 && val > local_vals[pos - 1]) {
1885                            pos--;
1886                        }
1887
1888                        for (int i = k - 1; i > pos; i--) {
1889                            local_vals[i] = local_vals[i - 1];
1890                            local_idx[i] = local_idx[i - 1];
1891                        }
1892                        local_vals[pos] = val;
1893                        local_idx[pos] = col;
1894                    }
1895                }
1896
1897                for (int i = 0; i < k; i++) {
1898                    shared_vals[tid * k + i] = local_vals[i];
1899                    shared_idx[tid * k + i] = local_idx[i];
1900                }
1901                item_ct1.barrier(sycl::access::fence_space::local_space);
1902
1903                if (tid == 0) {
1904                    float final_vals[32];
1905                    int final_idx[32];
1906
1907                    for (int i = 0; i < k; i++) {
1908                        final_vals[i] = -FLT_MAX;
1909                        final_idx[i] = -1;
1910                    }
1911
1912                    for (int t = 0; t < block_size; t++) {
1913                        for (int i = 0; i < k; i++) {
1914                            float val = shared_vals[t * k + i];
1915                            int idx = shared_idx[t * k + i];
1916
1917                            if (val > final_vals[k-1]) {
1918                                int pos = k - 1;
1919                                while (pos > 0 && val > final_vals[pos - 1]) {
1920                                    pos--;
1921                                }
1922
1923                                for (int j = k - 1; j > pos; j--) {
1924                                    final_vals[j] = final_vals[j - 1];
1925                                    final_idx[j] = final_idx[j - 1];
1926                                }
1927                                final_vals[pos] = val;
1928                                final_idx[pos] = idx;
1929                            }
1930                        }
1931                    }
1932
1933                    for (int i = 0; i < k; i++) {
1934                        dst_idx_row[i] = final_idx[i];
1935                    }
1936
1937                    if (k > 1) {
1938                        int32_t temp = dst_idx_row[0];
1939                        dst_idx_row[0] = dst_idx_row[1];
1940                        dst_idx_row[1] = temp;
1941                    }
1942                }
1943            });
1944    });
1945}
1946
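// Row-wise argmax: each work-item scans a strided subset of the columns, then a
// shared-memory tree reduction selects the block-wide maximum. The reduction
// stride is hard-coded to 256, so it assumes 256 work-items per group.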
1947static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
1948                               const int nrows, queue_ptr stream) {
1949    const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
1950    const sycl::range<3> block_nums(1, nrows, 1);
1951    const size_t shared_mem = 256 * sizeof(float);
1952
1953    stream->submit([&](sycl::handler &cgh) {
1954        sycl::local_accessor<float, 1> shared_data(
1955            sycl::range<1>(shared_mem/sizeof(float)), cgh);
1956        sycl::local_accessor<int, 1> shared_indices(
1957            sycl::range<1>(shared_mem/sizeof(float)), cgh);
1958
1959        cgh.parallel_for(
1960            sycl::nd_range<3>(block_nums * block_dims, block_dims),
1961            [=](sycl::nd_item<3> item_ct1) {
1962                const int tid = item_ct1.get_local_id(2);
1963                const int row = item_ct1.get_global_id(1);
1964
1965                float max_val = -INFINITY;
1966                int max_idx = -1;
1967
1968                for (int col = tid; col < ncols; col += 256) {
1969                    float val = x[row * ncols + col];
1970                    if (val > max_val) {
1971                        max_val = val;
1972                        max_idx = col;
1973                    }
1974                }
1975
1976                shared_data[tid] = max_val;
1977                shared_indices[tid] = max_idx;
1978                item_ct1.barrier(sycl::access::fence_space::local_space);
1979
1980                for (int stride = 256/2; stride > 0; stride >>= 1) {
1981                    if (tid < stride) {
1982                        float val1 = shared_data[tid];
1983                        float val2 = shared_data[tid + stride];
1984                        if (val2 > val1) {
1985                            shared_data[tid] = val2;
1986                            shared_indices[tid] = shared_indices[tid + stride];
1987                        }
1988                    }
1989                    item_ct1.barrier(sycl::access::fence_space::local_space);
1990                }
1991
1992
1993                if (tid == 0) {
1994                    dst[row] = shared_indices[0];
1995                }
1996            });
1997    });
1998}
1999static void diag_mask_inf_f32_sycl(const float *x, float *dst,
2000                                   const int ncols_x, const int nrows_x,
2001                                   const int rows_per_channel, const int n_past,
2002                                   queue_ptr stream) {
2003    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
2004    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
2005    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
2006    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
2007                         [=](sycl::nd_item<3> item_ct1) {
2008                             diag_mask_inf_f32(x, dst, ncols_x,
2009                                               rows_per_channel, n_past,
2010                                               item_ct1);
2011                         });
2012}
2013
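// Copies rows [i1_low, i1_high) of one (i3, i2) plane of a ggml tensor into a
// contiguous destination buffer, picking between a single memcpy (fully contiguous
// rows), a strided 2D memcpy (contiguous elements, strided rows), or a per-row
// fallback for arbitrary strides.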
2014static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
2015                                          const struct ggml_tensor *src,
2016                                          int64_t i3, int64_t i2,
2017                                          int64_t i1_low, int64_t i1_high,
2018                                          queue_ptr stream) try {
2019
2020    dpct::memcpy_direction kind;
2021    char * src_ptr;
2022    if (ggml_backend_buffer_is_host(src->buffer)) {
2023        kind = dpct::host_to_device;
2024        //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
2025        src_ptr = (char *) src->data;
2026        // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d  GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
2027    } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
2028        // If buffer is a SYCL buffer
2029        //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
2030        kind    = dpct::device_to_device;
2031        src_ptr = (char *) src->data;
2032    } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
2033        /*
2034        If buffer is a SYCL split buffer
2035        */
2036        //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
2037        GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
2038        kind = dpct::device_to_device;
2039        ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
2040        int id;
2041        SYCL_CHECK(CHECK_TRY_ERROR(
2042            id = get_current_device_id()));
2043        // GGML_SYCL_DEBUG("current device index %d\n", id);
2044        src_ptr = (char *) extra->data_device[id];
2045    } else {
2046        // GGML_SYCL_DEBUG("GGML_ABORT(\"fatal error\")\n");
2047        GGML_ABORT("fatal error");
2048    }
2049    char * dst_ptr = (char *) dst;
2050
2051    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
2052    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
2053    const enum ggml_type type = src->type;
2054    const int64_t ts = ggml_type_size(type);
2055    const int64_t bs = ggml_blck_size(type);
2056    int64_t i1_diff = i1_high - i1_low;
2057
2058    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
2059    if (nb0 == ts && nb1 == ts*ne0/bs) {
2060        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
2061        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
2062        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
2063                                    kind, *stream));
2064
2065    } else if (nb0 == ts) {
2066        return CHECK_TRY_ERROR(
2067            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
2068                                    ts * ne0 / bs, i1_diff, kind, *stream));
2069    } else {
2070        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
2071            const void * rx = (const void *) ((const char *) x + i1*nb1);
2072            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
2073            // pretend the row is a matrix with cols=1
2074            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2075                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
2076            /*
2077            DPCT1001:85: The statement could not be removed.
2078            */
2079            /*
2080            DPCT1000:86: Error handling if-stmt was detected but could not be
2081            rewritten.
2082            */
2083            if (r != 0) return r;
2084        }
2085        return 0;
2086    }
2087}
2088catch (sycl::exception const &exc) {
2089  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2090            << ", line:" << __LINE__ << std::endl;
2091  std::exit(1);
2092}
2093
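// Dense GEMM path for mul_mat: when GGML_SYCL_F16 is enabled and the shapes allow
// it, the inputs are converted to fp16 and a half-precision GEMM is used (with the
// result converted back to fp32); otherwise the inputs are converted to fp32 and an
// fp32 GEMM is used. Either path may be routed through oneDNN when it is built in
// and not disabled at runtime.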
2094inline void ggml_sycl_op_mul_mat_sycl(
2095    ggml_backend_sycl_context & ctx,
2096    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
2097    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
2098    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
2099    const int64_t src1_ncols, const int64_t src1_padded_row_size,
2100    const queue_ptr &stream) try {
2101
2102    GGML_ASSERT(src0_dd_i  != nullptr);
2103    GGML_ASSERT(src1_ddf_i != nullptr);
2104    GGML_ASSERT(dst_dd_i   != nullptr);
2105
2106    const int64_t ne00 = src0->ne[0];
2107    const int64_t ne10 = src1->ne[0];
2108    GGML_ASSERT(ne00 == ne10);
2109
2110    const int64_t row_diff = row_high - row_low;
2111
2112    int id;
2113    SYCL_CHECK(
2114        CHECK_TRY_ERROR(id = get_current_device_id()));
2115
2116    const int64_t ne0 = dst->ne[0]; // used by MKL only
2117    // the main device has a larger memory buffer to hold the results from all GPUs
2118    // ldc == nrows of the matrix that cuBLAS writes into
2119    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
2120
2121#ifdef GGML_SYCL_F16
2122    bool use_fp16 = true;  // TODO(Yu) SYCL capability check
2123#else
2124    bool use_fp16 = false;
2125#endif
2126    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
2127        row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
2128        ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
2129        if (src0->type != GGML_TYPE_F16) {
2130            scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2131                                                 " : converting src0 to fp16");
2132            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
2133            GGML_ASSERT(to_fp16_sycl != nullptr);
2134            size_t ne = row_diff*ne00;
2135            src0_as_f16.alloc(ne);
2136            to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
2137        }
2138        const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
2139                                         ? (const sycl::half *)src0_dd_i
2140                                         : src0_as_f16.get();
2141
2142        ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
2143        if (src1->type != GGML_TYPE_F16) {
2144            scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2145                                                 " : converting src1 to fp16");
2146            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2147            GGML_ASSERT(to_fp16_sycl != nullptr);
2148            size_t ne = src1_ncols*ne10;
2149            src1_as_f16.alloc(ne);
2150            to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
2151        }
2152        const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
2153                ? (const sycl::half *)src1->data + src1_padded_row_size
2154                                         : src1_as_f16.get();
2155
2156#if GGML_SYCL_DNNL
2157        if (!g_ggml_sycl_disable_dnn) {
2158            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
2159                                      DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2160                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2161        }
2162        else
2163#endif
2164        {
2165            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2166
2167            const sycl::half alpha_f16 = 1.0f;
2168            const sycl::half beta_f16  = 0.0f;
2169            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2170                *stream, oneapi::mkl::transpose::trans,
2171                oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2172                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
2173                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2174                dst_f16.get(), dpct::library_data_t::real_half, ldc,
2175                dpct::library_data_t::real_half)));
2176            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2177                                                 " : converting dst to fp32");
2178            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2179            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2180        }
2181    } else {
2182        ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
2183        ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
2184        if (src0->type != GGML_TYPE_F32) {
2185            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2186                                                 " : converting src0 to fp32");
2187            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
2188            GGML_ASSERT(to_fp32_sycl != nullptr);
2189            src0_ddq_as_f32.alloc(row_diff*ne00);
2190            to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
2191        }
2192        if (src1->type != GGML_TYPE_F32) {
2193            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2194                                                 " : converting src1 to fp32");
2195            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
2196            GGML_ASSERT(to_fp32_sycl != nullptr);
2197            src1_ddq_as_f32.alloc(src1_ncols*ne10);
2198            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
2199        }
2200        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
2201        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
2202
2203#if GGML_SYCL_DNNL
2204        if (!g_ggml_sycl_disable_dnn) {
2205            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
2206                                      DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2207                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2208        }
2209        else
2210#endif
2211        {
2212            const float alpha = 1.0f;
2213            const float beta  = 0.0f;
2214            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2215                *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
2216                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
2217                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2218        }
2219    }
2220    GGML_UNUSED(dst);
2221    GGML_UNUSED(src1_ddq_i);
2222    GGML_UNUSED(src1_padded_row_size);
2223}
2224catch (sycl::exception const &exc) {
2225  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2226            << ", line:" << __LINE__ << std::endl;
2227  std::exit(1);
2228}
2229
2230static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2231    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2232    GGML_ASSERT( dst->type == GGML_TYPE_F32);
2233    dpct::queue_ptr main_stream = ctx.stream();
2234    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2235    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2236    float *       dst_dd  = static_cast<float *>(dst->data);
2237
2238    const int32_t * opts = (const int32_t *)dst->op_params;
2239    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
2240    const int k0 = opts[1];
2241    const int k1 = opts[2];
2242    const int s0 = opts[3];
2243    const int s1 = opts[4];
2244    const int p0 = opts[5];
2245    const int p1 = opts[6];
2246
2247    const int64_t IH = dst->src[0]->ne[1];
2248    const int64_t IW = dst->src[0]->ne[0];
2249
2250    const int64_t N = dst->ne[3];
2251    const int64_t OC = dst->ne[2];
2252    const int64_t OH = dst->ne[1];
2253    const int64_t OW = dst->ne[0];
2254
2255    const int parallel_elements = N * OC * OH * OW;
2256    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
2257    sycl::range<3> block_nums(1, 1, num_blocks);
2258    main_stream->parallel_for(
2259        sycl::nd_range<3>(block_nums *
2260                              sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE),
2261                          sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE)),
2262        [=](sycl::nd_item<3> item_ct1) {
2263            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
2264                               parallel_elements, src0_dd, dst_dd, op,
2265                               item_ct1);
2266        });
2267}
2268
2269inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2270    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2271    GGML_ASSERT( dst->type == GGML_TYPE_F32);
2272    dpct::queue_ptr main_stream = ctx.stream();
2273    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2274    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2275    float *       dst_dd  = static_cast<float *>(dst->data);
2276
2277    const int64_t ne = ggml_nelements(dst->src[0]);
2278
2279    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
2280}
2281
2282inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2283    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2284    GGML_ASSERT( dst->type == GGML_TYPE_F32);
2285    dpct::queue_ptr main_stream = ctx.stream();
2286    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2287    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2288    float *       dst_dd  = static_cast<float *>(dst->data);
2289
2290    const int64_t ncols = dst->src[0]->ne[0];
2291    const int64_t nrows = ggml_nrows(dst->src[0]);
2292
2293    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2294}
2295
2296inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2297    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2298    GGML_ASSERT(dst->type == GGML_TYPE_F32);
2299
2300    dpct::queue_ptr main_stream = ctx.stream();
2301    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2302
2303    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2304    float *       dst_dd  = static_cast<float *>(dst->data);
2305
2306    const int64_t ncols = dst->src[0]->ne[0];
2307    const int64_t nrows = ggml_nrows(dst->src[0]);
2308
2309    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2310
2311    main_stream->parallel_for(
2312        sycl::range<1>(nrows),
2313        [=](sycl::id<1> row) {
2314            dst_dd[row] /= ncols;
2315        }
2316    );
2317}
2318
2319
2320inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2321    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2322    GGML_ASSERT(dst->type == GGML_TYPE_I32);
2323    dpct::queue_ptr main_stream = ctx.stream();
2324    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2325    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2326    int32_t *       dst_dd  = static_cast<int32_t *>(dst->data);
2327
2328
2329    const int64_t ncols = dst->src[0]->ne[0];
2330    const int64_t nrows = ggml_nrows(dst->src[0]);
2331
2332    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2333
2334    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
2335                         main_stream, ctx.device);
2336}
2337
2338static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2339    const ggml_tensor * src0 = dst->src[0];
2340
2341    GGML_ASSERT(src0);
2342    GGML_ASSERT(src0->type == GGML_TYPE_F32);
2343    GGML_ASSERT(dst->type == GGML_TYPE_I32);
2344    GGML_ASSERT(ggml_is_contiguous(src0));
2345
2346    dpct::queue_ptr main_stream = ctx.stream();
2347    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2348
2349    const float * src0_dd = static_cast<const float *>(src0->data);
2350    int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2351
2352    const int k = dst->ne[0];
2353    const int64_t ncols = src0->ne[0];
2354    const int64_t nrows = ggml_nrows(src0);
2355
2356    GGML_ASSERT(k > 0 && k <= 32);
2357    GGML_ASSERT(k <= ncols);
2358
2359    top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
2360}
2361
2362inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2363    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2364    GGML_ASSERT( dst->type == GGML_TYPE_I32);
2365
2366    dpct::queue_ptr main_stream = ctx.stream();
2367    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2368    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2369    int32_t *       dst_dd  = static_cast<int32_t *>(dst->data);
2370
2371    const int64_t ncols = dst->src[0]->ne[0];
2372    const int64_t nrows = ggml_nrows(dst->src[0]);
2373
2374    argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2375}
2376
2377inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2378    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2379    GGML_ASSERT( dst->type == GGML_TYPE_F32);
2380    dpct::queue_ptr main_stream = ctx.stream();
2381    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2382    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2383    float *       dst_dd  = static_cast<float *>(dst->data);
2384
2385    const int64_t ne00 = dst->src[0]->ne[0];
2386    const int64_t ne01 = dst->src[0]->ne[1];
2387    const int nrows0 = ggml_nrows(dst->src[0]);
2388
2389    const int n_past = ((int32_t *) dst->op_params)[0];
2390
2391    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
2392}
2393
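// Element-wise triangular masking: keep src[idx] when the element's (i0, i1)
// position satisfies the selected ggml_tri_type predicate (strictly lower/upper or
// including the diagonal), otherwise write 0.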
2394static void tri_f32_sycl(
2395    const float * src,
2396    float * dst,
2397    const int64_t ne0,
2398    const int64_t ne1,
2399    const int64_t ne2,
2400    const int64_t ne3,
2401    const ggml_tri_type ttype,
2402    dpct::queue_ptr main_stream
2403) {
2404    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
2405
2406    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
2407        const int64_t idx = (int64_t) tid[0];
2408
2409        const int64_t i0 = idx % ne0;
2410        const int64_t t1 = idx / ne0;
2411        const int64_t i1 = t1 % ne1;
2412
2413        bool keep = false;
2414        switch (ttype) {
2415            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;
2416            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
2417            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;
2418            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
2419            default: keep = false; break;
2420        }
2421
2422        dst[idx] = keep ? src[idx] : 0.0f;
2423    });
2424}
2425
2426static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2427    const ggml_tensor * src0 = dst->src[0];
2428    GGML_ASSERT(src0);
2429
2430    GGML_ASSERT(src0->type == GGML_TYPE_F32);
2431    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
2432    GGML_ASSERT(ggml_is_contiguous(src0));
2433    GGML_ASSERT(ggml_is_contiguous(dst));
2434    GGML_ASSERT(ggml_are_same_shape(src0, dst));
2435
2436    dpct::queue_ptr main_stream = ctx.stream();
2437    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2438
2439    const float * src0_dd = static_cast<const float *>(src0->data);
2440    float *       dst_dd  = static_cast<float *>(dst->data);
2441
2442    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2443
2444    const int64_t ne0 = src0->ne[0];
2445    const int64_t ne1 = src0->ne[1];
2446    const int64_t ne2 = src0->ne[2];
2447    const int64_t ne3 = src0->ne[3];
2448
2449    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
2450}
2451
2452
2453inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2454    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2455    GGML_ASSERT( dst->type == GGML_TYPE_F32);
2456    dpct::queue_ptr main_stream = ctx.stream();
2457    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2458    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2459    float *       dst_dd  = static_cast<float *>(dst->data);
2460
2461    float scale;
2462    float bias;
2463    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
2464    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));
2465
2466    scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
2467    /*
2468    DPCT1010:87: SYCL uses exceptions to report errors and does not use the
2469    error codes. The call was replaced with 0. You need to rewrite this code.
2470    */
2471    SYCL_CHECK(0);
2472}
2473
2474static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
2475    static bool peer_access_enabled = false;
2476
2477    const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
2478
2479    if (peer_access_enabled == enable_peer_access) {
2480        return;
2481    }
2482
2483#ifdef NDEBUG
2484    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2485        SYCL_CHECK(ggml_sycl_set_device(i));
2486    }
2487
2488    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2489        SYCL_CHECK(ggml_sycl_set_device(i));
2490
2491        for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
2492            if (i == id_other) {
2493                continue;
2494            }
2495            if (i != main_device && id_other != main_device) {
2496                continue;
2497            }
2498
2499            // int can_access_peer;
2500            // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
2501            // if (can_access_peer) {
2502            //     if (enable_peer_access) {
2503            //         SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
2504            //     } else {
2505            //         SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
2506            //     }
2507            // }
2508        }
2509    }
2510#endif // NDEBUG
2511
2512    peer_access_enabled = enable_peer_access;
2513}
2514
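// Generic multi-device mul_mat driver: splits src0 rows across devices according to
// the split buffer's tensor_split (rounded to tile boundaries), stages per-device
// copies of src0/src1/dst where needed, quantizes src1 to q8_1 when the selected
// `op` consumes quantized input, and walks src1 in column blocks so secondary
// devices can synchronize with the main device through events.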
2515template <template <int> typename quantize_f>
2516static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2517                                 const ggml_tensor *src1, ggml_tensor *dst,
2518                                 ggml_sycl_op_mul_mat_t op) try {
2519
2520    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
2521
2522    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
2523    const int64_t nrows1 = ggml_nrows(src1);
2524
2525    GGML_ASSERT(ne03 == ne13);
2526
2527    const int64_t ne0 = dst->ne[0];
2528    const int64_t ne1 = dst->ne[1];
2529
2530    const int nb2 = dst->nb[2];
2531    const int nb3 = dst->nb[3];
2532
2533    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
2534    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
2535    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
2536
2537    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
2538
2539    const int64_t i02_divisor = ne12 / ne02;
2540
2541    const size_t src0_ts = ggml_type_size(src0->type);
2542    const size_t src0_bs = ggml_blck_size(src0->type);
2543    const size_t q8_1_ts = sizeof(block_q8_1);
2544    const size_t q8_1_bs = QK8_1;
2545
2546    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
2547    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
2548
2549    const bool src0_is_contiguous = ggml_is_contiguous(src0);
2550    const bool src1_is_contiguous = ggml_is_contiguous(src1);
2551
2552    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
2553
2554    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2555    GGML_ASSERT(!(split && ne02 > 1));
2556    GGML_ASSERT(!(split && ne03 > 1));
2557    GGML_ASSERT(!(split && ne02 < ne12));
2558
2559    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
2560    if (split) {
2561        // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
2562        // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
2563        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
2564        tensor_split = buft_ctx->tensor_split;
2565    }
2566
2567    struct dev_data {
2568        ggml_sycl_pool_alloc<char> src0_dd_alloc;
2569        ggml_sycl_pool_alloc<float> src1_ddf_alloc;
2570        ggml_sycl_pool_alloc<char> src1_ddq_alloc;
2571        ggml_sycl_pool_alloc<float> dst_dd_alloc;
2572
2573        char *src0_dd = nullptr;
2574        float *src1_ddf = nullptr; // float
2575        char *src1_ddq = nullptr;  // q8_1
2576        float *dst_dd = nullptr;
2577
2578        int64_t row_low;
2579        int64_t row_high;
2580    };
2581
2582    dev_data dev[GGML_SYCL_MAX_DEVICES];
2583
2584    int used_devices = 0;
2585    queue_ptr main_stream = ctx.stream();
2586
2587    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2588        // by default, use all rows
2589        dev[i].row_low  = 0;
2590        dev[i].row_high = ne01;
2591
2592        // for multi GPU, get the row boundaries from tensor split
2593        // and round to mul_mat_q tile sizes
2594        if (split) {
2595            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
2596
2597            if (i != 0) {
2598                dev[i].row_low  = ne01*tensor_split[i];
2599                if (dev[i].row_low < ne01) {
2600                    dev[i].row_low -= dev[i].row_low % rounding;
2601                }
2602            }
2603
2604            if (i != ggml_sycl_info().device_count - 1) {
2605                dev[i].row_high  = ne01*tensor_split[i + 1];
2606                if (dev[i].row_high < ne01) {
2607                    dev[i].row_high -= dev[i].row_high % rounding;
2608                }
2609            }
2610        }
2611    }
2612
2613    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
2614                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
2615    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2616        if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
2617            continue;
2618        }
2619
2620        used_devices++;
2621
2622        const bool src1_on_device = i == ctx.device;
2623        const bool  dst_on_device = i == ctx.device;
2624
2625        ggml_sycl_set_device(i);
2626        queue_ptr stream = ctx.stream(i, 0);
2627
2628        if (src0_is_contiguous) {
2629            dev[i].src0_dd = (char *) src0->data;
2630        } else {
2631            dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
2632        }
2633
2634        if (src1_on_device && src1_is_contiguous) {
2635            dev[i].src1_ddf = (float *) src1->data;
2636        } else {
2637            dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
2638        }
2639
2640        if constexpr(quantize_enabled) {
2641            dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
2642
2643            if (src1_on_device && src1_is_contiguous) {
2644                scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2645                                                     /*num_src=*/2, " : converting src1 to Q8_1");
2646                try {
2647                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2648                } catch (sycl::exception const &exc) {
                    std::cerr << "Quantize_row_q8_1_sycl error: " << exc.what() << ". Exception caught at file:" << __FILE__
                              << ", line:" << __LINE__ << std::endl;
2651                    std::exit(1);
2652                }
2653            }
2654        }
2655
2656        if (dst_on_device) {
2657            dev[i].dst_dd = (float *) dst->data;
2658        } else {
2659            const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
2660            dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
2661        }
2662    }
2663
2664    // if multiple devices are used they need to wait for the main device
2665    // here an event is recorded that signals that the main device has finished calculating the input data
2666    if (split && used_devices > 1) {
2667        ggml_sycl_set_device(ctx.device);
2668        SYCL_CHECK(CHECK_TRY_ERROR(
2669            *src0_extra->events[ctx.device][0] =
2670                ctx.stream()->ext_oneapi_submit_barrier()));
2671    }
2672
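    // Walk src1 in chunks of MUL_MAT_SRC1_COL_STRIDE columns when the work is split across
    // devices; each chunk is assigned round-robin to one of GGML_SYCL_MAX_STREAMS streams per
    // device so that copies and compute for different chunks can overlap.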
2673    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
2674    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
2675        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
2676        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
2677        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2678            if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
2679                continue;
2680            }
2681
2682            const bool src1_on_device = i == ctx.device;
2683            const bool  dst_on_device = i == ctx.device;
2684            const int64_t row_diff = dev[i].row_high - dev[i].row_low;
2685
2686            ggml_sycl_set_device(i);
2687            queue_ptr stream = ctx.stream(i, is);
2688
2689            // wait for main GPU data if necessary
2690            if (split && (i != ctx.device || is != 0)) {
2691                SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
2692                    {*src0_extra->events[ctx.device][0]})));
2693            }
2694
2695            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
2696                const int64_t i03 = i0 / ne12;
2697                const int64_t i02 = i0 % ne12;
2698
2699                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
2700
2701                // for split tensors the data begins at i0 == i0_offset_low
2702                char  *  src0_dd_i =  dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
2703                float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
2704                char  * src1_ddq_i = dev[i].src1_ddq +  src1_ddq_i_offset;
2705                float *   dst_dd_i =   dev[i].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
2706
                // the main device memory buffer can be on VRAM scratch, with space for all partial results
                // in that case an offset on dst_dd_i is needed
2709                if (i == ctx.device) {
2710                    dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
2711                }
2712
2713                // copy src0, src1 to device if necessary
2714                if (src1_is_contiguous) {
2715                    if (i != ctx.device) {
2716                        if constexpr (quantize_enabled) {
2717                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
2718                            SYCL_CHECK(
2719                                CHECK_TRY_ERROR(stream
2720                                                    ->memcpy(src1_ddq_i, src1_ddq_i_source,
2721                                                             src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
2722                                                    .wait()));
2723                        } else {
2724                            float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
2725                            src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
2726
2727                            SYCL_CHECK(
2728                                CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
2729                                                               src1_ncols * ne10 * sizeof(float))));
2730                        }
2731                    }
2732                } else {
2733                    if (src1_on_device) {
2734                        SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
2735                                                           src1_col_0 + src1_ncols, stream));
2736                    } else {
2737                        GGML_ABORT("src1 is non-contiguous and not on device");
2738                    }
2739
2740                    if constexpr (quantize_enabled) {
2741                        scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2742                                                             /*num_src=*/2, " : converting src1 to Q8_1");
2743                        try {
2744                            quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
2745                                                                  src1_padded_col_size, stream);
2746                        } catch (const sycl::exception & exc) {
                            std::cerr << "Quantize_row_q8_1_sycl error: " << exc.what()
                                      << ". Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2749                            std::exit(1);
2750                        }
2751                    }
2752                }
2753
2754                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
2755                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
2756                }
2757                if (src1->type == GGML_TYPE_F16) {
2758                    src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
2759                }
2760                // do the computation
2761                SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
2762                    dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
2763
2764                // copy dst to host or other device if necessary
2765                if (!dst_on_device) {
2766                    void * dst_off_device = dst->data;
2767                    if (split) {
2768                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
2769                        // dst is NOT transposed.
2770                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
2771                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
                        // If dst is a vector with ne0 == 1 this copy is not strictly necessary, but it still produces correct results.
2773                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
2774                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
2775                        dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
2776
2777                        SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2778                            dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
2779                            row_diff * sizeof(float), row_diff * sizeof(float),
2780                            src1_ncols, dpct::device_to_device, *stream)));
2781                    } else {
2782                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
2783                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
2784                        dhf_dst_i += src1_col_0*ne0;
2785                        SYCL_CHECK(CHECK_TRY_ERROR(
2786                            stream->memcpy(dhf_dst_i, dst_dd_i,
2787                                           src1_ncols * ne0 * sizeof(float)).wait()));
2788                    }
2789                }
2790
2791                // add event for the main device to wait on until other device is done
2792                if (split && (i != ctx.device || is != 0)) {
2793                    SYCL_CHECK(CHECK_TRY_ERROR(
2794                        *src0_extra->events[i][is] =
2795                            stream->ext_oneapi_submit_barrier()));
2796                }
2797            }
2798        }
2799    }
2800
2801    // main device waits for all other devices to be finished
2802    if (split && ggml_sycl_info().device_count > 1) {
2803        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
2804        is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
2805
2806        ggml_sycl_set_device(ctx.device);
2807        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2808            if (dev[i].row_low == dev[i].row_high) {
2809                continue;
2810            }
2811            for (int64_t is = 0; is < is_max; ++is) {
2812                SYCL_CHECK(CHECK_TRY_ERROR(
2813                    ctx.stream()->ext_oneapi_submit_barrier(
2814                        {*src0_extra->events[i][is]})));
2815            }
2816        }
2817    }
2818}
2819catch (sycl::exception const &exc) {
2820  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2821            << ", line:" << __LINE__ << std::endl;
2822  std::exit(1);
2823}
2824
2825static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2826    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2827    ggml_sycl_op_repeat_back(ctx, dst);
2828}
2829
2830static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2831    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2832    ggml_sycl_op_get_rows(ctx, dst);
2833}
2834
2835static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2836    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2837    ggml_sycl_op_norm(ctx, dst);
2838}
2839
2840static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2841    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2842    ggml_sycl_op_rms_norm(ctx, dst);
2843}
2844
2845static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2846    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2847    ggml_sycl_op_rms_norm_back(ctx, dst);
2848}
2849
2850static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2851    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2852    ggml_sycl_op_l2_norm(ctx, dst);
2853}
2854
2855static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2856    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2857    ggml_sycl_op_group_norm(ctx, dst);
2858}
2859
2860static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2861                                       const ggml_tensor *src1,
2862                                       ggml_tensor *dst) try {
2863    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
2864    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2865    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
2866    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
2867    GGML_ASSERT(src0->type == GGML_TYPE_F16);
2868    GGML_ASSERT(src1->type == GGML_TYPE_F32);
2869
2870    const int64_t ne00 = src0->ne[0];
2871    const int64_t ne01 = src0->ne[1];
2872    const int64_t ne02 = src0->ne[2];
2873
2874    const int64_t ne12 = src1->ne[2];
2875
2876    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2877    queue_ptr main_stream = ctx.stream();
2878
2879    void  * src0_ddq = src0->data;
2880    float * src1_ddf = (float *) src1->data;
2881    float * dst_ddf  = (float *) dst->data;
2882
2883    ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
2884}
2885catch (sycl::exception const &exc) {
2886  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2887            << ", line:" << __LINE__ << std::endl;
2888  std::exit(1);
2889}
2890
2891static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2892                                     const ggml_tensor *src1,
2893                                     ggml_tensor *dst) try {
2894    GGML_ASSERT(!ggml_is_transposed(src0));
2895    GGML_ASSERT(!ggml_is_transposed(src1));
2896    GGML_ASSERT(!ggml_is_permuted(src0));
2897    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2898    GGML_ASSERT(src0->type == GGML_TYPE_F16);
2899    GGML_ASSERT(src1->type == GGML_TYPE_F32);
2900    GGML_ASSERT(src1->ne[1] == 1);
2901    GGML_ASSERT(src1->ne[3] == 1);
2902
2903    const int64_t ne00 = src0->ne[0];
2904    const int64_t ne01 = src0->ne[1];
2905    const int64_t ne02 = src0->ne[2];
2906
2907    const int64_t nb01 = src0->nb[1];
2908    const int64_t nb02 = src0->nb[2];
2909
2910    const int64_t ne12 = src1->ne[2];
2911    const int64_t nb11 = src1->nb[1];
2912
2913    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2914    queue_ptr main_stream = ctx.stream();
2915
2916    void  * src0_ddq = src0->data;
2917    float * src1_ddf = (float *) src1->data;
2918    float * dst_ddf  = (float *) dst->data;
2919
2920    const int64_t row_stride_x = nb01 / sizeof(sycl::half);
2921    const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
2922    const int64_t channel_stride_y = nb11 / sizeof(float);
2923
2924    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
2925}
2926catch (sycl::exception const &exc) {
2927  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2928            << ", line:" << __LINE__ << std::endl;
2929  std::exit(1);
2930}
2931
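// Builds the per-batch pointer tables used by the batched GEMM call below: one work-item per
// (i12, i13) batch index stores the src0/src1/dst base addresses for that batch, mapping
// broadcasted src0 batches through the ratios r2 = ne12/ne02 and r3 = ne13/ne03.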
2932static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
2933                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
2934                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
2935                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
2936    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
2937    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
2938
2939    if (i13 >= ne13 || i12 >= ne12) {
2940        return;
2941    }
2942
2943    const int64_t i03 = i13 / r3;
2944    const int64_t i02 = i12 / r2;
2945
2946    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
2947    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
2948    uint8_t *       dst_bytes  = static_cast<uint8_t *>(dst);
2949
2950    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
2951    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
2952    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
2953}
2954
2955static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
2956                                           const ggml_tensor * src1, ggml_tensor * dst) try {
2957    GGML_ASSERT(!ggml_is_transposed(src0));
2958    GGML_ASSERT(!ggml_is_transposed(src1));
2959    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2960    GGML_ASSERT(src0->type == GGML_TYPE_F16);
2961    GGML_ASSERT(dst->type == GGML_TYPE_F32);
2962
2963    GGML_TENSOR_BINARY_OP_LOCALS
2964
2965    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
2966    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
2967    GGML_ASSERT(ggml_is_contiguous(dst));
2968
2969    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2970    queue_ptr queue = ctx.stream();
2971
2972    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
2973
2974    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
2975    float *            dst_ddf  = static_cast<float *>(dst->data);
2976
2977    const sycl::half * src1_f16       = static_cast<const sycl::half *>(src1->data);
2978    const size_t       type_size_src0 = ggml_type_size(src0->type);
2979    const size_t       type_size_src1 = ggml_type_size(src1->type);
2980
2981    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
2982    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
2983
2984    // SRC1 strides
2985    int64_t                          s11 = nb11 / type_size_src1;
2986    int64_t                          s12 = nb12 / type_size_src1;
2987    int64_t                          s13 = nb13 / type_size_src1;
2988    ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
2989
2990    // convert src1 to fp16
2991    if (src1->type != GGML_TYPE_F16) {
2992        scope_op_debug_print    scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
2993                                                " : converting src1 to fp16");
2994
        // iterate over the tensor dims to find the slowest moving dim and its stride
        int    last_dim    = 0;
        int    last_str    = 0;
        size_t largest_str = 0;
        for (int i = 0; i < 4; i++) {
            // the slowest moving dim has the largest stride; when strides tie,
            // move past dims whose extent is 1
            if (src1->nb[i] == largest_str) {
                if (src1->ne[last_dim] == 1) {
                    last_str = i;
                    last_dim = i;
                }
            }
            if (src1->nb[i] > largest_str) {
                largest_str = src1->nb[i];
                last_str    = i;
                last_dim    = i;
            }
        }
3014#if GGML_SYCL_DNNL
        // oneDNN handles strided data and does not need the overhead of get_to_fp16_nc_sycl
3016        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
3017        src1_f16_alloc.alloc(ne_src1);
3018        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
3019        GGML_ASSERT(to_fp16_sycl != nullptr);
3020        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
#else
3022        const int64_t ne_src1 = ggml_nelements(src1);
3023        src1_f16_alloc.alloc(ne_src1);
3024        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
3025        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
3026        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
3027#endif
3028
3029        src1_f16 = src1_f16_alloc.get();
3030        s11      = ne10;
3031        s12      = ne11 * s11;
3032        s13      = ne12 * s12;
3033
3034        is_src1_cont_2 = true;
3035    }
3036
3037    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
3038
3039    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
3040    dpct::library_data_t mkl_data_type    = dpct::library_data_t::real_float;
3041
3042    // dst strides
3043    size_t nbd2 = dst->nb[2];
3044    size_t nbd3 = dst->nb[3];
3045
3046    const float alpha_f32 = 1.0f;
3047    const float beta_f32  = 0.0f;
3048
3049    const void * alpha = &alpha_f32;
3050    const void * beta  = &beta_f32;
3051
3052    GGML_ASSERT(ne12 % ne02 == 0);
3053    GGML_ASSERT(ne13 % ne03 == 0);
3054    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
3055    GGML_ASSERT(ne10 == ne00);
3056
3057    // broadcast factors
3058    const int64_t r2 = ne12 / ne02;
3059    const int64_t r3 = ne13 / ne03;
3060
3061#if GGML_SYCL_DNNL
3062    if (!g_ggml_sycl_disable_dnn) {
3063            int64_t str_a0 = nb00 / type_size_src0;
3064            int64_t str_a1 = nb01 / type_size_src0;
3065            int64_t str_a2 = nb02 / type_size_src0;
3066
3067            int64_t str_b0 = nb10 / type_size_src1;
3068            int64_t str_b1 = nb11 / type_size_src1;
3069            int64_t str_b2 = nb12 / type_size_src1;
3070
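            // Issues the batched oneDNN GEMM: when the two batch counts match (or one of them is 1)
            // a single broadcasted gemm call is enough; otherwise the smaller batch is iterated and
            // each of its matrices is multiplied against a sub-batch of the larger one.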
3071            auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
3072                                                const sycl::half *src1, float *dst,
3073                                                int64_t a0, int64_t a1, int64_t batcha,
3074                                                int64_t /*b0*/, int64_t b1, int64_t batchb,
3075                                                int64_t sa0, int64_t sa1, int64_t sa2,
3076                                                int64_t sb0, int64_t sb1, int64_t sb2,
3077                                                int64_t sd2) {
                const bool supported_broadcast = batcha == batchb || batcha == 1 || batchb == 1;
3081                if (supported_broadcast) {
3082                    DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
3083                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
3084                            DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
3085                            DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
3086                } else {
                    // the batch counts differ and neither is 1: iterate over the smaller batch and
                    // multiply each of its matrices against a sub-batch of the larger one
3088                    int64_t batches0 = batcha;
3089                    int64_t batches1 = batchb;
3090
3091                    if (batches0 > batches1) {
3092                        int64_t num_mul_mats = batches1;
3093                        int64_t sub_batch = batches0 / num_mul_mats;
3094                        // src0 is batched and bigger, shift and multiply with src1
3095                        for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
3096                            const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
3097                            const sycl::half *src1_shifted = src1 + (sb2 * i0);
3098                            float *dst_shifted = dst + (sd2 * i0 * sub_batch);
3099                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
3100                                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
3101                                    src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
3102                                    sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
3103                                    queue, sub_batch, 1);
3104                        }
3105                    } else {
3106                        int64_t num_mul_mats = batches0;
3107                        int64_t sub_batch = batches1 / num_mul_mats;
3108                        // src1 is batched and bigger, shift and multiply with src0
3109                        for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
3110                            const sycl::half *src0_shifted = src0 + (sa2 * i1);
3111                            const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
3112                            float *dst_shifted = dst + (sd2 * i1 * sub_batch);
3113                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
3114                                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
3115                                    src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
3116                                    sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
3117                                    queue, 1, sub_batch);
3118                        }
3119                    }
3120                }
3121            };
3122
3123            const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
3124            const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
3125            const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
3126            const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
3127            if (cont_batches_dim2_a && cont_batches_dim2_b) {
3128                // A batch is considered contiguous if the dimension 2 is not strided
3129                int64_t batches0 = ne02 * ne03;
3130                int64_t batches1 = ne12 * ne13;
3131                launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
3132                        ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
3133                        str_b2, nb2 / sizeof(float));
3134            } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
3135                // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
3136                int64_t batches0 = ne02 * ne03;
3137                int64_t batches1 = ne12 * ne13;
3138                int64_t str_a3 = nb03 / type_size_src0;
3139                int64_t str_b3 = nb13 / type_size_src1;
3140                launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
3141                        ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
3142                        str_b3, nb2 / sizeof(float));
3143            } else {
3144                for (int64_t b_a = 0; b_a < ne03; b_a++) {
3145                    const sycl::half *src0_f16_shifted
3146                            = src0_f16 + (nb03 * b_a / type_size_src0);
3147                    const sycl::half *src1_f16_shifted
3148                            = src1_f16 + (nb13 * b_a / type_size_src1);
3149                    float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
3150                    int64_t batches0 = ne02;
3151                    int64_t batches1 = ne12;
3152                    launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
3153                            ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
3154                            str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
3155                }
3156            }
3157
3158    }
3159    else
3160#endif
3161    {
3162        if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
3163            // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
3164            const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
3165            const int64_t smb = ne12 == 1 ? s13       : s12;
3166
3167            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
3168            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
3169                                                        oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3170                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
3171                                                        src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
3172                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
3173        } else {
3174            const int ne23 = ne12 * ne13;
3175
3176            ggml_sycl_pool_alloc<const void *>         ptrs_src(ctx.pool(), 2 * ne23);
3177            ggml_sycl_pool_alloc<void *>               ptrs_dst(ctx.pool(), 1 * ne23);
3178            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
3179
3180            sycl::range<3> block_dims(1, ne12, ne13);
3181            queue->submit([&](sycl::handler & cgh) {
3182                const void ** ptrs_src_get = ptrs_src.get();
3183                void **       ptrs_dst_get = ptrs_dst.get();
3184                size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
3185                size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
3186                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3187                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
3188                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
3189                });
3190            });
3191
3192            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3193                *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3194                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3195                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
3196                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
3197        }
3198    }
3199} catch (const sycl::exception & exc) {
3200    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3201    std::exit(1);
3202}
3203
3204enum class mul_mat_algo {
3205    DMMV         = 0,
3206    MMVQ         = 1,
3207    MUL_MAT_SYCL = 2,
3208};
3209
3210inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
3211    // TODO: accuracy issues in MMQ
3212    GGML_UNUSED(type);
3213    return false;
3214}
3215
3216inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
3217    switch (type) {
3218        case GGML_TYPE_Q4_0:
3219            return true;
3220        case GGML_TYPE_Q4_K:
3221        case GGML_TYPE_Q6_K:
3222            return !g_ggml_sycl_prioritize_dmmv;
3223        default:
3224            return false;
3225    }
3226}
3227
3228inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
3229    switch (type) {
3230        case GGML_TYPE_Q4_0:
3231            return true;
3232        default:
3233            return false;
3234    }
3235}
3236
3237inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
3238    switch (type) {
3239        case GGML_TYPE_Q4_0:
3240        case GGML_TYPE_Q4_K:
3241        case GGML_TYPE_Q6_K:
3242            return true;
3243        default:
3244            return false;
3245    }
3246}
3247
3248static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3249    switch (type) {
3250        case GGML_TYPE_Q4_0:
3251        case GGML_TYPE_Q4_1:
3252        case GGML_TYPE_Q5_0:
3253        case GGML_TYPE_Q5_1:
3254        case GGML_TYPE_Q8_0:
3255        case GGML_TYPE_Q2_K:
3256        case GGML_TYPE_Q3_K:
3257        case GGML_TYPE_Q4_K:
3258        case GGML_TYPE_Q5_K:
3259        case GGML_TYPE_Q6_K:
3260        case GGML_TYPE_F16:
3261            return true;
3262        default:
3263            return false;
3264    }
3265}
3266
3267// Helper functions to unify device memory allocation for both async and sync paths
3268static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3269    bool use_async = g_ggml_sycl_use_async_mem_op;
3270#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3271    if (use_async) {
3272        return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
3273    }
3274#else
3275    // If async allocation extension is not available, use_async should always be false.
3276    GGML_ASSERT(!use_async);
3277#endif
3278    return sycl::malloc(size, *stream, sycl::usm::alloc::device);
3279}
3280
3281static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3282    bool use_async = g_ggml_sycl_use_async_mem_op;
3283#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3284    if (use_async) {
3285        syclex::async_free(*stream, ptr);
3286        return;
3287    }
3288#else
3289    // If async allocation extension is not available, use_async should always be false.
3290    GGML_ASSERT(!use_async);
3291#endif
3292    sycl::free(ptr, *stream);
3293}
3294
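// Repacks Q4_0 data in place from an array-of-blocks layout into a struct-of-arrays layout:
// all quant nibbles (qs) first, followed by all half-precision scales (d). The original data is
// first copied to a temporary device buffer so data_device can be rewritten in place.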
3295static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
3296                            dpct::queue_ptr stream) {
3297    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3298
3299    sycl::event copy_event;
3300    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3301    if (!g_ggml_sycl_use_async_mem_op) {
3302        copy_event.wait();
3303    }
3304
3305    GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3306    GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
    int offset_blks = offset / sizeof(block_q4_0);
    auto qs_ptr     = data_device + offset_blks * QK4_0 / 2;
    auto d_ptr      = (sycl::half *) (qs_ptr + ncols * nrows / 2) + offset_blks;

    auto reorder_event = stream->parallel_for(
        size / sizeof(block_q4_0), [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
            const block_q4_0 * x  = (const block_q4_0 *) tmp_buf;
            const int          ib = i;

            for (int j = 0; j < QK4_0 / 2; j++) {
                *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
            }
            *(d_ptr + ib) = x[ib].d;
        });
3323    if (!g_ggml_sycl_use_async_mem_op) {
3324        reorder_event.wait_and_throw();
3325    }
3326    sycl_ext_free(stream, tmp_buf);
3327}
3328
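// Same in-place struct-of-arrays repacking for Q4_K: qs for all blocks, then scales, then the
// half2 dm values.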
3329static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3330    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
3331    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
3332
3333    const int nblocks = size / sizeof(block_q4_K);
3334
3335    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3336
3337    sycl::event copy_event;
3338    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3339    if (!g_ggml_sycl_use_async_mem_op) {
3340        copy_event.wait();
3341    }
3342
3343    auto * qs_ptr     = data_device;
3344    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3345    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3346
3347    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3348        const block_q4_K * x  = (const block_q4_K *) tmp_buf;
3349        const int          ib = i;
3350
3351        for (int j = 0; j < QK_K / 2; ++j) {
3352            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
3353        }
3354
3355        for (int j = 0; j < K_SCALE_SIZE; ++j) {
3356            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
3357        }
3358
3359        dm_ptr[ib] = x[ib].dm;
3360    });
3361    if (!g_ggml_sycl_use_async_mem_op) {
3362        reorder_event.wait_and_throw();
3363    }
3364    sycl_ext_free(stream, tmp_buf);
3365}
3366
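// In-place struct-of-arrays repacking for Q6_K: ql, then qh, then scales, then the half d values,
// each grouped across all blocks.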
3367static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3368    GGML_ASSERT(size % sizeof(block_q6_K) == 0);
3369    GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
3370
3371    const int nblocks = size / sizeof(block_q6_K);
3372
3373    uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3374
3375    sycl::event copy_event;
3376    SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3377    if (!g_ggml_sycl_use_async_mem_op) {
3378        copy_event.wait();
3379    }
3380
3381    auto *       ql_ptr     = data_device;
3382    auto *       qh_ptr     = ql_ptr + (QK_K / 2) * nblocks;
3383    auto *       scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3384    sycl::half * dm_ptr     = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3385
3386    auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3387        const block_q6_K * x  = (const block_q6_K *) tmp_buf;
3388        const int          ib = i;
3389
3390        const uint8_t * ql              = x[ib].ql;
3391        const uint8_t * qh              = x[ib].qh;
3392        uint8_t *       base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;
3393        uint8_t *       base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;
3394        uint8_t *       base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3395
3396        for (int j = 0; j < QK_K / 2; ++j) {
3397            base_ql_ptr[j] = ql[j];
3398        }
3399        for (int j = 0; j < QK_K / 4; ++j) {
3400            base_qh_ptr[j] = qh[j];
3401        }
3402
3403        for (int j = 0; j < QK_K / 16; ++j) {
3404            base_scales_ptr[j] = x[ib].scales[j];
3405        }
3406
3407        dm_ptr[ib] = x[ib].d;
3408    });
3409    if (!g_ggml_sycl_use_async_mem_op) {
3410        reorder_event.wait_and_throw();
3411    }
3412    sycl_ext_free(stream, tmp_buf);
3413}
3414
3415static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
3416    uint8_t * data_device = (uint8_t *) src0->data;
3417    size_t ncols = src0->ne[0];
3418    size_t nrows = src0->ne[1];
3419    size_t size = ggml_nbytes(src0);
3420
3421    switch (src0->type) {
3422        case GGML_TYPE_Q4_0:
3423            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
3424            break;
3425        case GGML_TYPE_Q4_K:
3426            reorder_qw_q4_k(data_device, size, 0, stream);
3427            break;
3428        case GGML_TYPE_Q6_K:
3429            reorder_qw_q6_k(data_device, size, 0, stream);
3430            break;
3431        default:
3432            GGML_ABORT("reorder_qw() called with unsupported type");
3433            break;
3434    }
3435}
3436
3437static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
    return !g_ggml_sycl_disable_optimize && // optimization allowed, controlled by the GGML_SYCL_DISABLE_OPT env var
            ctx.opt_feature.reorder &&      // only devices where the reorder layout performs well
            dst->op == GGML_OP_MUL_MAT &&   // limited to the supported MUL_MAT cases (Q4_0 etc.); more cases to do
            dst->src[1]->ne[1] == 1 && dst->src[1]->ne[2] == 1 && dst->src[1]->ne[3] == 1; // src1 is a single row (vector)
3442}
3443
3444static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
3445                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {
3446    if (!should_reorder_tensor(*ctx, dst)) {
3447        return;
3448    }
3449
3450    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3451    if (!extra || extra->optimized_feature.reorder) {
3452        return;  // Skip permutations and already reordered tensors
3453    }
3454
3455    switch (mm_algorithm) {
3456        case mul_mat_algo::DMMV:
3457            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
3458                return;
3459            }
3460            break;
3461        case mul_mat_algo::MMVQ:
3462            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
3463                return;
3464            }
3465            break;
3466        case mul_mat_algo::MUL_MAT_SYCL:
3467            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
3468                return;
3469            }
3470            break;
3471    }
3472
3473    reorder_qw(src0, ctx->stream());
3474    extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering
3475}
3476
3477
3478static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3479    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3480           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3481}
3482
3483static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3484    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3485           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3486}
3487
3488static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3489    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3490    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
3491    int64_t min_compute_capability = INT_MAX;
3492
3493    if (split) {
3494        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
3495            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
3496        auto & tensor_split = buft_ctx->tensor_split;
3497        for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
3498            // skip devices that are not going to do any work:
3499            if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
3500                continue;
3501            }
3502
3503            if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
3504                min_compute_capability = ggml_sycl_info().devices[id].cc;
3505            }
3506        }
3507    } else {
3508        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
3509    }
3510
3511    // check data types and tensor shapes for custom matrix multiplication kernels:
3512    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
3513
3514    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
3515
3516    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
3517        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
3518
3519
3520    // mmvq and mmq need the __dp4a instruction which is available for gen12+
3521    // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3522    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
3523#ifdef SYCL_USE_XMX
3524    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3525#endif // SYCL_USE_XMX
3526
    // Dispatch becomes obscure with the reorder optimization: when it is enabled,
    // MMVQ takes precedence over DMMV, so the current if-else implementation
    // requires disabling DMMV whenever both paths are applicable.
3530    if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
3531                                          ggml_sycl_supports_reorder_mmvq(src0->type)))) {
3532        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
3533    }
3534
3535    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
3536        // TODO: Refactor and cleanup of mul mat dispatching.
3537        if (src0->ne[3] == 1 && src1->ne[3] == 1) {
3538            // KQ single-batch
            // the p021 mul_mat_vec kernel is specific to these dimensions
3540            ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
3541        } else {
3542            // The kernel from the if path is faster for that specific case, but does not support all mul mats.
3543            ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3544        }
3545    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
3546        // KQV single-batch
3547        ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
3548    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
3549        // KQ + KQV multi-batch
3550        ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3551    } else if (use_dequantize_mul_mat_vec) {
3552        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3553        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
3554    } else if (use_mul_mat_vec_q) {
3555        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3556        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3557        if (extra && extra->optimized_feature.reorder) {
3558            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3559        } else {
3560            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3561        }
3562    } else if (use_mul_mat_q) {
3563        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
3564    } else {
3565        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
3566    }
3567}
3568
3569
3570struct mmid_row_mapping {
3571    int32_t i1;
3572    int32_t i2;
3573};
3574
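// Gathers the src1 rows selected for expert i02 into a contiguous buffer: one work-group is
// launched per entry of the ids tensor; entries matching i02 atomically claim the next slot in
// src1_contiguous and record their (i1, i2) destination in row_mapping for the scatter pass below.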
3575__dpct_inline__ static void k_copy_src1_to_contiguous(
3576    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3577    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3578    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3579    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3580    const sycl::nd_item<3> &item_ct1, int &src1_row) {
3581    int32_t iid1 = item_ct1.get_group(2);
3582    int32_t id = item_ct1.get_group(1);
3583
3584    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3585
3586    if (row_id_i != i02) {
3587        return;
3588    }
3589
3590    const int64_t i11 = id % ne11;
3591    const int64_t i12 = iid1;
3592
3593    if (item_ct1.get_local_id(2) == 0) {
3594        src1_row =
3595            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3596                cur_src1_row, 1);
3597        row_mapping[src1_row] = {id, iid1};
3598    }
3599    /*
3600    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3601    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3602    performance if there is no access to global memory.
3603    */
3604    item_ct1.barrier();
3605
3606    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
3607    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
3608
3609#pragma unroll
3610    for (int i = item_ct1.get_local_id(2); i < ne10;
3611         i += item_ct1.get_local_range(2)) {
3612        src1_row_contiguous[i] = src1_row_original[i];
3613    }
3614}
3615
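// Scatter pass: copies each row of the contiguous dst buffer back to the (i1, i2) position
// recorded in row_mapping by the gather kernel above.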
3616__dpct_inline__ static void k_copy_dst_from_contiguous(
3617    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
3618    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
3619    size_t nb2, const sycl::nd_item<3> &item_ct1) {
3620    int32_t i = item_ct1.get_group(2);
3621
3622    const int32_t i1 = row_mapping[i].i1;
3623    const int32_t i2 = row_mapping[i].i2;
3624
3625    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
3626    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
3627
3628#pragma unroll
3629    for (int j = item_ct1.get_local_id(2); j < ne0;
3630         j += item_ct1.get_local_range(2)) {
3631        dst_row_original[j] = dst_row_contiguous[j];
3632    }
3633}
3634
3635static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3636                                 ggml_tensor *dst) try {
3637    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
3638    const ggml_tensor *src0 = dst->src[0];
3639    const ggml_tensor *src1 = dst->src[1];
3640    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
3641
3642    const ggml_tensor *ids = dst->src[2];
3643    GGML_TENSOR_BINARY_OP_LOCALS
3644
3645    const queue_ptr stream = ctx.stream();
3646
3647    const int64_t n_as = ne02;
3648    const int64_t n_ids = ids->ne[0];
3649
3650    std::vector<char> ids_host(ggml_nbytes(ids));
3651    const char * ids_dev = (const char *) ids->data;
3652
3653    SYCL_CHECK(CHECK_TRY_ERROR(
3654        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
3655    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
3656
3657    ggml_tensor src0_row = *src0;
3658    ggml_tensor src1_row = *src1;
3659    ggml_tensor dst_row = *dst;
3660
3661    char *src0_original = (char *)src0->data;
3662    char *src1_original = (char *)src1->data;
3663    char *dst_original = (char *)dst->data;
3664
3665    src0_row.ne[2] = 1;
3666    src0_row.ne[3] = 1;
3667    src0_row.nb[3] = nb02;
3668
3669    src1_row.ne[1] = 1;
3670    src1_row.ne[2] = 1;
3671    src1_row.ne[3] = 1;
3672    src1_row.nb[2] = nb11;
3673    src1_row.nb[3] = nb11;
3674
3675    dst_row.ne[1] = 1;
3676    dst_row.ne[2] = 1;
3677    dst_row.ne[3] = 1;
3678    dst_row.nb[2] = nb1;
3679    dst_row.nb[3] = nb1;
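    // With a single src1 batch (ne12 == 1) each selected expert row is multiplied directly as its
    // own mul_mat; otherwise the rows belonging to each expert are gathered into contiguous
    // buffers, multiplied once per expert, and scattered back to their original positions.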
3680    if (ne12 == 1) {
3681        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3682            for (int64_t id = 0; id < n_ids; id++) {
3683                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
3684                GGML_ASSERT(i02 >= 0 && i02 < n_as);
3685
3686                const int64_t i11 = id % ne11;
3687                const int64_t i12 = iid1;
3688
3689                const int64_t i1 = id;
3690                const int64_t i2 = i12;
3691
                src0_row.data = src0_original + i02*nb02;
                src1_row.data = src1_original + i11*nb11 + i12*nb12;
                dst_row.data = dst_original + i1*nb1 + i2*nb2;

                ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3697            }
3698        }
3699    } else {
3700        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
3701        ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
3702
3703        src1_row.data = src1_contiguous.get();
3704        dst_row.data  =  dst_contiguous.get();
3705
3706        for (int64_t i02 = 0; i02 < n_as; i02++) {
3707            int64_t num_src1_rows = 0;
3708            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3709                for (int64_t id = 0; id < n_ids; id++) {
3710                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
3711
3712                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
3713
3714                    if (row_id_i != i02) {
3715                        continue;
3716                    }
3717
3718                    num_src1_rows++;
3719                }
3720            }
3721
3722            if (num_src1_rows == 0) {
3723                continue;
3724            }
3725
3726
3727            ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
3728            ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
3729            SYCL_CHECK(CHECK_TRY_ERROR(
3730                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
3731
3732            const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
3733            assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
3734
3735            {
3736                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
3737                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
3738                stream->submit([&](sycl::handler &cgh) {
3739                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
3740
3741                    char *__restrict src1_contiguous_get =
3742                        src1_contiguous.get();
3743                    int *__restrict dev_cur_src1_row_get =
3744                        dev_cur_src1_row.get();
3745                    mmid_row_mapping *__restrict dev_row_mapping_get =
3746                        dev_row_mapping.get();
3747                    size_t ids_nb_ct6 = ids->nb[1];
3748                    size_t ids_nb_ct7 = ids->nb[0];
3749
3750                    cgh.parallel_for(
3751                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3752                        [=](sycl::nd_item<3> item_ct1) {
3753                            k_copy_src1_to_contiguous(
3754                                src1_original, src1_contiguous_get,
3755                                dev_cur_src1_row_get,
3756                                dev_row_mapping_get, ids_dev, i02,
3757                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
3758                                item_ct1, src1_row_acc);
3759                        });
3760                });
3761            }
3762
3763            src0_row.data = src0_original + i02*nb02;
3764
3765            GGML_ASSERT(nb11 == sizeof(float)*ne10);
3766            GGML_ASSERT(nb1 == sizeof(float)*ne0);
3767            src1_row.ne[1] = num_src1_rows;
3768
3769            src1_row.nb[1] = nb11;
3770            src1_row.nb[2] = num_src1_rows*nb11;
3771            src1_row.nb[3] = num_src1_rows*nb11;
3772
3773            dst_row.ne[1] = num_src1_rows;
3774            dst_row.nb[1] = nb1;
3775            dst_row.nb[2] = num_src1_rows*nb1;
3776            dst_row.nb[3] = num_src1_rows*nb1;
3777
3778            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3779
3780            {
3781                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
3782                sycl::range<3> grid_dims(1, 1, num_src1_rows);
3783                stream->submit([&](sycl::handler &cgh) {
3784                    const char *__restrict dst_contiguous_get =
3785                        dst_contiguous.get();
3786                    const mmid_row_mapping *__restrict dev_row_mapping_get =
3787                        dev_row_mapping.get();
3788
3789                    cgh.parallel_for(
3790                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3791                        [=](sycl::nd_item<3> item_ct1) {
3792                            k_copy_dst_from_contiguous(dst_original,
3793                                                       dst_contiguous_get,
3794                                                       dev_row_mapping_get,
3795                                                       ne0, nb1, nb2, item_ct1);
3796                        });
3797                });
3798            }
3799        }
3800    }
3801}
3802catch (sycl::exception const &exc) {
3803  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3804            << ", line:" << __LINE__ << std::endl;
3805  std::exit(1);
3806}
3807
3808static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3809    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3810    ggml_sycl_op_scale(ctx, dst);
3811}
3812
3813static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3814    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3815    ggml_sycl_op_diag_mask_inf(ctx, dst);
3816}
3817
3818static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3819    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3820    ggml_sycl_op_pool2d(ctx, dst);
3821}
3822
3823static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3824    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3825    ggml_sycl_op_im2col(ctx, dst);
3826}
3827
3828static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3829    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3830    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3831    ggml_sycl_op_sum(ctx, dst);
3832}
3833
3834static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3835    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3836    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3837    ggml_sycl_op_sum_rows(ctx, dst);
3838}
3839
3840static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3841    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3842    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3843    ggml_sycl_op_mean(ctx, dst);
3844}
3845
3846static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3847    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3848    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3849    ggml_sycl_op_argsort(ctx, dst);
3850}
3851
3852static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3853    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3854    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3855    ggml_sycl_op_argmax(ctx, dst);
3856}
3857
3858
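// Selects `main_device` as the current SYCL device for subsequent dpct calls; returns early if
// it is already the active device. With debug logging enabled, the chosen device name is printed.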
3859static void ggml_sycl_set_main_device(const int main_device) try {
3860    if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
3861        return;
3862    }
3863    check_allow_gpu_index(main_device);
3864    dpct::select_device(main_device);
3865
3866    if (g_ggml_sycl_debug) {
3867        dpct::device_info prop;
3868        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
3869            prop, dpct::dev_mgr::instance().get_device(main_device))));
3870        GGML_LOG_INFO("Using device %d (%s) as main device\n",
3871                main_device, prop.get_name());
3872    }
3873}
3874catch (sycl::exception const &exc) {
3875  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3876            << ", line:" << __LINE__ << std::endl;
3877  std::exit(1);
3878}
3879
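// Dispatches one graph node to its SYCL implementation. If the first source lives in a split
// buffer, peer access between the participating devices is enabled first. Returns false when
// the op (or the unary/GLU sub-op) has no SYCL kernel, so the caller can report it as unsupported.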
3880static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
3881    if (!g_sycl_loaded) return false;
3882
3883    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
3884        ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
3885    }
3886
3887    switch (dst->op) {
3888        case GGML_OP_ARGMAX:
3889            ggml_sycl_argmax(ctx, dst);
3890            break;
3891        case GGML_OP_CONV_TRANSPOSE_1D:
3892            ggml_sycl_op_conv_transpose_1d(ctx, dst);
3893            break;
3894        case GGML_OP_REPEAT:
3895            ggml_sycl_repeat(ctx, dst);
3896            break;
3897        case GGML_OP_REPEAT_BACK:
3898            ggml_sycl_repeat_back(ctx, dst);
3899            break;
3900        case GGML_OP_GET_ROWS:
3901            ggml_sycl_get_rows(ctx, dst);
3902            break;
3903        case GGML_OP_SET:
3904            ggml_sycl_op_set(ctx, dst);
3905            break;
3906        case GGML_OP_SET_ROWS:
3907            ggml_sycl_op_set_rows(ctx, dst);
3908            break;
3909        case GGML_OP_DUP:
3910            ggml_sycl_dup(ctx, dst);
3911            break;
3912        case GGML_OP_ADD:
3913        case GGML_OP_ADD1: // TODO: more efficient implementation
3914            ggml_sycl_add(ctx, dst);
3915            break;
3916        case GGML_OP_ADD_ID:
3917            ggml_sycl_add_id(ctx, dst);
3918            break;
3919        case GGML_OP_SUB:
3920            ggml_sycl_sub(ctx, dst);
3921            break;
3922        case GGML_OP_COUNT_EQUAL:
3923            ggml_sycl_count_equal(ctx, dst);
3924            break;
3925        case GGML_OP_ACC:
3926            ggml_sycl_acc(ctx, dst);
3927            break;
3928        case GGML_OP_MUL:
3929            ggml_sycl_mul(ctx, dst);
3930            break;
3931        case GGML_OP_LOG:
3932            ggml_sycl_log(ctx, dst);
3933            break;
3934        case GGML_OP_DIV:
3935            ggml_sycl_div(ctx, dst);
3936            break;
3937        case GGML_OP_UNARY:
3938            switch (ggml_get_unary_op(dst)) {
3939                case GGML_UNARY_OP_NEG:
3940                    ggml_sycl_neg(ctx, dst);
3941                    break;
3942                case GGML_UNARY_OP_STEP:
3943                    ggml_sycl_step(ctx, dst);
3944                    break;
3945                case GGML_UNARY_OP_GELU:
3946                    ggml_sycl_gelu(ctx, dst);
3947                    break;
3948                case GGML_UNARY_OP_SILU:
3949                    ggml_sycl_silu(ctx, dst);
3950                    break;
3951                case GGML_UNARY_OP_GELU_QUICK:
3952                    ggml_sycl_gelu_quick(ctx, dst);
3953                    break;
3954                case GGML_UNARY_OP_GELU_ERF:
3955                    ggml_sycl_gelu_erf(ctx, dst);
3956                    break;
3957                case GGML_UNARY_OP_TANH:
3958                    ggml_sycl_tanh(ctx, dst);
3959                    break;
3960                case GGML_UNARY_OP_RELU:
3961                    ggml_sycl_relu(ctx, dst);
3962                    break;
3963                case GGML_UNARY_OP_SIGMOID:
3964                    ggml_sycl_sigmoid(ctx, dst);
3965                    break;
3966                case GGML_UNARY_OP_HARDSIGMOID:
3967                    ggml_sycl_hardsigmoid(ctx, dst);
3968                    break;
3969                case GGML_UNARY_OP_HARDSWISH:
3970                    ggml_sycl_hardswish(ctx, dst);
3971                    break;
3972                case GGML_UNARY_OP_EXP:
3973                    ggml_sycl_exp(ctx, dst);
3974                    break;
3975                case GGML_UNARY_OP_SOFTPLUS:
3976                    ggml_sycl_softplus(ctx, dst);
3977                    break;
3978                case GGML_UNARY_OP_SGN:
3979                    ggml_sycl_sgn(ctx, dst);
3980                    break;
3981                case GGML_UNARY_OP_ABS:
3982                    ggml_sycl_abs(ctx, dst);
3983                    break;
3984                case GGML_UNARY_OP_ELU:
3985                    ggml_sycl_elu(ctx, dst);
3986                    break;
3987                case GGML_UNARY_OP_FLOOR:
3988                    ggml_sycl_floor(ctx, dst);
3989                    break;
3990                case GGML_UNARY_OP_CEIL:
3991                    ggml_sycl_ceil(ctx, dst);
3992                    break;
3993                case GGML_UNARY_OP_ROUND:
3994                    ggml_sycl_round(ctx, dst);
3995                    break;
3996                case GGML_UNARY_OP_TRUNC:
3997                    ggml_sycl_trunc(ctx, dst);
3998                    break;
3999                default:
4000                    return false;
4001            }
4002            break;
4003        case GGML_OP_GLU:
4004            switch (ggml_get_glu_op(dst)) {
4005                case GGML_GLU_OP_REGLU:
4006                    ggml_sycl_reglu(ctx, dst);
4007                    break;
4008                case GGML_GLU_OP_GEGLU:
4009                    ggml_sycl_geglu(ctx, dst);
4010                    break;
4011                case GGML_GLU_OP_SWIGLU:
4012                    ggml_sycl_swiglu(ctx, dst);
4013                    break;
4014                case GGML_GLU_OP_SWIGLU_OAI:
4015                    ggml_sycl_swiglu_oai(ctx, dst);
4016                    break;
4017                case GGML_GLU_OP_GEGLU_ERF:
4018                    ggml_sycl_geglu_erf(ctx, dst);
4019                    break;
4020                case GGML_GLU_OP_GEGLU_QUICK:
4021                    ggml_sycl_geglu_quick(ctx, dst);
4022                    break;
4023                default:
4024                    return false;
4025            }
4026            break;
4027        case GGML_OP_NORM:
4028            ggml_sycl_norm(ctx, dst);
4029            break;
4030        case GGML_OP_GROUP_NORM:
4031            ggml_sycl_group_norm(ctx, dst);
4032            break;
4033        case GGML_OP_CONCAT:
4034            ggml_sycl_op_concat(ctx, dst);
4035            break;
4036        case GGML_OP_PAD_REFLECT_1D:
4037            ggml_sycl_op_pad_reflect_1d(ctx,dst);
4038            break;
4039        case GGML_OP_UPSCALE:
4040            ggml_sycl_upscale(ctx, dst);
4041            break;
4042        case GGML_OP_PAD:
4043            ggml_sycl_pad(ctx, dst);
4044            break;
4045        case GGML_OP_LEAKY_RELU:
4046            ggml_sycl_leaky_relu(ctx, dst);
4047            break;
4048        case GGML_OP_RMS_NORM_BACK:
4049            ggml_sycl_rms_norm_back(ctx, dst);
4050            break;
4051        case GGML_OP_RMS_NORM:
4052            ggml_sycl_rms_norm(ctx, dst);
4053            break;
4054        case GGML_OP_L2_NORM:
4055            ggml_sycl_l2_norm(ctx, dst);
4056            break;
4057        case GGML_OP_MUL_MAT:
4058            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
4059                return false;
4060            }
4061            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
4062            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
4063            break;
4064        case GGML_OP_MUL_MAT_ID:
4065            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
4066                return false;
4067            }
4068            ggml_sycl_mul_mat_id(ctx, dst);
4069            break;
4070        case GGML_OP_OUT_PROD:
4071            ggml_sycl_op_out_prod(ctx, dst);
4072            break;
4073        case GGML_OP_SCALE:
4074            ggml_sycl_scale(ctx, dst);
4075            break;
4076        case GGML_OP_SQR:
4077            ggml_sycl_sqr(ctx, dst);
4078            break;
4079        case GGML_OP_SQRT:
4080            ggml_sycl_sqrt(ctx, dst);
4081            break;
4082        case GGML_OP_SIN:
4083            ggml_sycl_sin(ctx, dst);
4084            break;
4085        case GGML_OP_COS:
4086            ggml_sycl_cos(ctx, dst);
4087            break;
4088        case GGML_OP_CLAMP:
4089            ggml_sycl_clamp(ctx, dst);
4090            break;
4091        case GGML_OP_CPY:
4092            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
4093            break;
4094        case GGML_OP_CONT:
4095            ggml_sycl_dup(ctx, dst);
4096            break;
4097        case GGML_OP_NONE:
4098        case GGML_OP_RESHAPE:
4099        case GGML_OP_VIEW:
4100        case GGML_OP_PERMUTE:
4101        case GGML_OP_TRANSPOSE:
4102            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
4103            break;
4104        case GGML_OP_TRI:
4105            ggml_sycl_op_tri(ctx, dst);
4106            break;
4107        case GGML_OP_DIAG_MASK_INF:
4108            ggml_sycl_diag_mask_inf(ctx, dst);
4109            break;
4110        case GGML_OP_SOFT_MAX:
4111            ggml_sycl_op_soft_max(ctx, dst);
4112            break;
4113        case GGML_OP_SOFT_MAX_BACK:
4114            ggml_sycl_op_soft_max_back(ctx, dst);
4115            break;
4116        case GGML_OP_ROPE:
4117            ggml_sycl_rope(ctx, dst);
4118            break;
4119        case GGML_OP_IM2COL:
4120            ggml_sycl_im2col(ctx, dst);
4121            break;
4122        case GGML_OP_POOL_2D:
4123            ggml_sycl_pool2d(ctx, dst);
4124            break;
4125        case GGML_OP_SUM:
4126            ggml_sycl_sum(ctx, dst);
4127            break;
4128        case GGML_OP_SUM_ROWS:
4129            ggml_sycl_sum_rows(ctx, dst);
4130            break;
4131        case GGML_OP_MEAN:
4132            ggml_sycl_mean(ctx, dst);
4133            break;
4134        case GGML_OP_ARGSORT:
4135            ggml_sycl_argsort(ctx, dst);
4136            break;
4137        case GGML_OP_TOP_K:
4138            ggml_sycl_op_top_k(ctx, dst);
4139            break;
4140        case GGML_OP_TIMESTEP_EMBEDDING:
4141            ggml_sycl_op_timestep_embedding(ctx, dst);
4142            break;
4143        case GGML_OP_RWKV_WKV6:
4144            ggml_sycl_op_rwkv_wkv6(ctx, dst);
4145            break;
4146        case GGML_OP_RWKV_WKV7:
4147            ggml_sycl_op_rwkv_wkv7(ctx, dst);
4148            break;
4149        case GGML_OP_GATED_LINEAR_ATTN:
4150            ggml_sycl_op_gated_linear_attn(ctx, dst);
4151            break;
4152        case GGML_OP_SSM_CONV:
4153            ggml_sycl_ssm_conv(ctx, dst);
4154            break;
4155        case GGML_OP_ROLL:
4156            ggml_sycl_roll(ctx, dst);
4157            break;
4158        case GGML_OP_ARANGE:
4159            ggml_sycl_arange(ctx, dst);
4160            break;
4161        default:
4162            return false;
4163    }
4164
4165    return true;
} catch (sycl::exception const & e) {
    std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::cerr << "Error OP " << ggml_op_name(dst->op) << std::endl;
    std::exit(1);
4170}
4171
4172GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
4173                                      size_t description_size) try {
4174    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
4175    dpct::device_info prop;
4176    SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
4177        prop, dpct::dev_mgr::instance().get_device(device))));
4178    snprintf(description, description_size, "%s", prop.get_name());
4179}
4180catch (sycl::exception const &exc) {
4181  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4182            << ", line:" << __LINE__ << std::endl;
4183  std::exit(1);
4184}
4185
4186void ggml_backend_sycl_get_device_memory(int device, size_t *free,
4187                                                   size_t *total) try {
4188    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
4189    ggml_sycl_set_device(device);
4190
4191    SYCL_CHECK(CHECK_TRY_ERROR(
4192        dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
4193}
4194catch (sycl::exception const &exc) {
4195  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4196            << ", line:" << __LINE__ << std::endl;
4197  std::exit(1);
4198}
4199
4200////////////////////////////////////////////////////////////////////////////////
4201
4202// backend
4203
4204static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
4205
4206    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4207
4208    return sycl_ctx->name.c_str();
4209}
4210
4211static void ggml_backend_sycl_free(ggml_backend_t backend) {
4212    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4213
4214    delete sycl_ctx;
4215    delete backend;
4216}
4217
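// The async set/get/copy helpers below enqueue memcpys on stream 0 of the context device and
// return without waiting; completion is observed through ggml_backend_sycl_synchronize().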
4218static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
4219                                               ggml_tensor *tensor,
4220                                               const void *data, size_t offset,
4221                                               size_t size) try {
4222    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4223    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
4224    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4225    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4226    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4227
4228    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
4229    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4230    SYCL_CHECK(CHECK_TRY_ERROR(
4231        (stream)->memcpy((char *)tensor->data + offset, data, size)));
4232}
4233catch (sycl::exception const &exc) {
4234  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4235            << ", line:" << __LINE__ << std::endl;
4236  std::exit(1);
4237}
4238
4239static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
4240                                               const ggml_tensor *tensor,
4241                                               void *data, size_t offset,
4242                                               size_t size) try {
4243    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4244    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
4245    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4246    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4247    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4248
4249    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
4250    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4251    SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4252        data, (const char *)tensor->data + offset, size)));
4253}
4254catch (sycl::exception const &exc) {
4255  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4256            << ", line:" << __LINE__ << std::endl;
4257  std::exit(1);
4258}
4259
4260static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
4261                                               const ggml_tensor *src,
4262                                               ggml_tensor *dst) try {
4263    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4264    bool is_cpy_supported                = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
4265                            ggml_backend_buffer_is_sycl(src->buffer);
4266    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4267    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
4268    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
4269    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
4270    if (is_cpy_supported) {
4271        /*
4272        DPCT1009:215: SYCL uses exceptions to report errors and does not use the
4273        error codes. The original code was commented out and a warning string
4274        was inserted. You need to rewrite this code.
4275        */
4276        const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4277        SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4278            dst->data, src->data, ggml_nbytes(dst))));
4279        return true;
4280    }
4281
4282    return false;
4283}
4284catch (sycl::exception const &exc) {
4285  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4286            << ", line:" << __LINE__ << std::endl;
4287  std::exit(1);
4288}
4289
4290static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
4291    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4292    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4293    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4294    SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
4295
4296    GGML_UNUSED(backend);
4297}
4298catch (sycl::exception const &exc) {
4299  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4300            << ", line:" << __LINE__ << std::endl;
4301  std::exit(1);
4302}
4303
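// Eager execution path: walks the cgraph in order, skipping empty/view/reshape/permute/transpose
// nodes and nodes not flagged for compute, and asserts if any remaining op is unsupported.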
4304static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
4305    ggml_sycl_set_main_device(sycl_ctx->device);
4306
4307    for (int i = 0; i < cgraph->n_nodes; i++) {
4308        ggml_tensor * node = cgraph->nodes[i];
4309        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
4310            continue;
4311        }
4312        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
4313            continue;
4314        }
4315#ifndef NDEBUG
4316        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
4317        for (int j = 0; j < GGML_MAX_SRC; j++) {
4318            if (node->src[j] != nullptr) {
4319                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
4320            }
4321        }
4322#endif
4323        bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
4324        if (!ok) {
4325            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
4326        }
4327        GGML_ASSERT(ok);
4328    }
4329}
4330
4331#ifdef GGML_SYCL_GRAPH
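// Decides whether an entire cgraph can be recorded into a SYCL-Graph. Multi-device setups and
// ops that block on the host or allocate while recording (CONCAT, MUL_MAT_ID, and MUL_MAT
// without the async-alloc extension) force the eager path instead.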
4332static bool check_graph_compatibility(ggml_cgraph * cgraph) {
4333    if (ggml_sycl_info().device_count > 1) {
4334        // A sycl_ex::command_graph object can only be created for a single device
4335        GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
4336        return false;
4337    }
4338
4339    for (int i = 0; i < cgraph->n_nodes; i++) {
4340        const ggml_op node_op = cgraph->nodes[i]->op;
4341        switch (node_op) {
4342            default:
4343                break;
4344            case GGML_OP_CONCAT:
4345                // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
4346                // but wait() can't be called on the events returned by a queue recording
4347                // to a graph.
4348                [[fallthrough]];
4349            case GGML_OP_MUL_MAT_ID:
4350                // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
4351                // submitting a memcpy operation, but wait() can't be called on a queue that
4352                // is recording to a graph.
4353                GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
4354                              ggml_op_name(node_op));
4355                return false;
4356            case GGML_OP_MUL_MAT:
                // We cannot use graphs with ggml_sycl_mul_mat() when the SYCL async memory allocation
                // extension is not available, as the reorder path then performs SYCL malloc/free and
                // host wait calls, none of which are supported while recording to a graph.
4360                if (!g_ggml_sycl_use_async_mem_op) {
4361                    GGML_LOG_INFO(
4362                        "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
4363                        "oneAPI async memory allocation extension "
4364                        "%s\n",
4365                        __func__, ggml_op_name(node_op));
4366                    return false;
4367                }
4368        }
4369    }
4370    return true;
4371}
4372#endif
4373
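// Graph compute entry point: when SYCL-Graph support is compiled in, enabled, and the cgraph is
// compatible with the device, node submissions are recorded into a command_graph and replayed
// (updated in place when ext_oneapi_graph is supported); otherwise nodes run eagerly.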
4374static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4375    auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
4376
4377#ifdef GGML_SYCL_GRAPH
4378    bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
4379    if (use_sycl_graph) {
4380        const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
4381        if (!graph_support) {
            GGML_SYCL_DEBUG("[SYCL-GRAPH] cannot use graphs on device:%d\n", sycl_ctx->device);
4383            ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4384            return GGML_STATUS_SUCCESS;
4385        }
4386
4387        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
4388
4389        model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
4390        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4391        model_sycl_graph.end_recording();
4392
4393        const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
4394        if (!sycl_ctx->exec_graph || !graph_update_support) {
4395            auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
4396                                                     model_sycl_graph.finalize();
4397            sycl_ctx->exec_graph = std::make_unique<
4398                sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4399        } else {
4400            try {
4401                sycl_ctx->exec_graph->update(model_sycl_graph);
4402                GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
4403            } catch (sycl::exception const & e) {
4404                GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
4405                auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
4406                sycl_ctx->exec_graph = std::make_unique<
4407                    sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4408            }
4409        }
4410
4411        sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
4412    } else
4413#endif
4414    {
4415        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4416    }
4417    return GGML_STATUS_SUCCESS;
4418}
4419
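// Records the current state of the backend's queue into the ggml event by submitting a barrier
// and storing the returned sycl::event.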
4420static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
4421try
4422{
4423    ggml_backend_sycl_context *sycl_ctx =
4424        (ggml_backend_sycl_context *)backend->context;
4425
4426    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4427
4428    const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
4429    // Record the current state of the queue
4430    SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
4431}
4432catch (sycl::exception const &exc)
4433{
4434    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4435              << ", line:" << __LINE__ << std::endl;
4436    std::exit(1);
4437}
4438
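// Blocks until the recorded sycl::event has completed; only meaningful for SYCL backends.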
4439static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
4440    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4441    sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
4442
4443    if (ggml_backend_is_sycl(backend)) {
4444        SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
    } else {
        GGML_ABORT("fatal error");
    }
4447} catch (sycl::exception const& exc) {
4448    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4449              << ", line:" << __LINE__ << std::endl;
4450    std::exit(1);
4451}
4452
4453static ggml_backend_i ggml_backend_sycl_interface = {
4454    /* .get_name                = */ ggml_backend_sycl_get_name,
4455    /* .free                    = */ ggml_backend_sycl_free,
4456    /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,
4457    /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,
4458    /* .cpy_tensor_async        = */ NULL, // ggml_backend_sycl_cpy_tensor_async,
4459                                           // // TODO: update for the new
4460                                           // interface
4461    /* .synchronize             = */ ggml_backend_sycl_synchronize,
4462    /* .graph_plan_create       = */ NULL,
4463    /* .graph_plan_free         = */ NULL,
4464    /* .graph_plan_update       = */ NULL,
4465    /* .graph_plan_compute      = */ NULL,
4466    /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
4467    /* .event_record            = */ ggml_backend_sycl_event_record,
4468    /* .event_wait              = */ ggml_backend_sycl_event_wait,
4469    /* .graph_optimize          = */ NULL,
4470};
4471
4472static ggml_guid_t ggml_backend_sycl_guid() {
4473    static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
4474    return &guid;
4475}
4476
4477bool ggml_backend_is_sycl(ggml_backend_t backend) {
4478    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
4479}
4480
4481int ggml_backend_sycl_get_device_count() {
4482    return ggml_sycl_info().device_count;
4483}
4484
4485
4486// backend device
4487
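// Per-device state attached to each ggml_backend_dev_t created in ggml_backend_sycl_reg().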
4488struct ggml_backend_sycl_device_context {
4489    int device;
4490    std::string name;
4491    std::string description;
4492    int op_offload_min_batch_size;
4493};
4494
4495static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
4496    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4497    return ctx->name.c_str();
4498}
4499
4500static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
4501    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4502    return ctx->description.c_str();
4503}
4504
4505static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
4506    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4507    ggml_sycl_set_device(ctx->device);
4508    SYCL_CHECK(CHECK_TRY_ERROR(
4509    dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
4510}
4511
4512static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
4513    GGML_UNUSED(dev);
4514    return GGML_BACKEND_DEVICE_TYPE_GPU;
4515}
4516
4517static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
4518    props->name        = ggml_backend_sycl_device_get_name(dev);
4519    props->description = ggml_backend_sycl_device_get_description(dev);
4520    props->type        = ggml_backend_sycl_device_get_type(dev);
4521    ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
4522
4523    bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
4524#ifdef GGML_SYCL_NO_PEER_COPY
4525    bool events = false;
4526#else
4527    bool events = true;
4528#endif
4529
4530    props->caps = {
4531        /* .async                 = */ true,
4532        /* .host_buffer           = */ host_buffer,
4533        /* .buffer_from_host_ptr  = */ false,
4534        /* .events                = */ events,
4535    };
4536}
4537
4538static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
4539    GGML_UNUSED(params);
4540    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4541    return ggml_backend_sycl_init(ctx->device);
4542}
4543
4544static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
4545    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4546    return ggml_backend_sycl_buffer_type(ctx->device);
4547}
4548
4549static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
4550    GGML_UNUSED(dev);
4551    return ggml_backend_sycl_host_buffer_type();
4552}
4553
4554static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
4555    GGML_UNUSED(dev);
4556    GGML_UNUSED(ptr);
4557    GGML_UNUSED(size);
4558    GGML_UNUSED(max_tensor_size);
4559    return nullptr;
4560}
4561
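// Capability table of the SYCL backend: reports whether a given op / dtype / layout combination
// can run on this device. Anything not explicitly handled falls through to `return false` and is
// left to another backend.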
4562static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4563    ggml_backend_sycl_device_context *sycl_ctx =
4564        (ggml_backend_sycl_device_context *)dev->context;
4565    int device = sycl_ctx->device;
4566    switch (op->op) {
4567        case GGML_OP_CONV_TRANSPOSE_1D:
4568            {
4569                ggml_type src0_type = op->src[0]->type;
4570                ggml_type src1_type = op->src[1]->type;
4571                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4572                    return true;
4573                }
4574                return false;
4575            }
4576        case GGML_OP_UNARY:
4577            switch (ggml_get_unary_op(op)) {
4578                case GGML_UNARY_OP_SGN:
4579                case GGML_UNARY_OP_ABS:
4580                case GGML_UNARY_OP_NEG:
4581                case GGML_UNARY_OP_STEP:
4582                case GGML_UNARY_OP_RELU:
4583                case GGML_UNARY_OP_HARDSIGMOID:
4584                case GGML_UNARY_OP_TANH:
4585                case GGML_UNARY_OP_GELU:
4586                case GGML_UNARY_OP_SILU:
4587                case GGML_UNARY_OP_SIGMOID:
4588                case GGML_UNARY_OP_HARDSWISH:
4589                case GGML_UNARY_OP_GELU_QUICK:
4590                case GGML_UNARY_OP_GELU_ERF:
4591                case GGML_UNARY_OP_EXP:
4592                case GGML_UNARY_OP_SOFTPLUS:
4593                case GGML_UNARY_OP_ELU:
4594                case GGML_UNARY_OP_CEIL:
4595                    return true;
4596                case GGML_UNARY_OP_FLOOR:
4597                case GGML_UNARY_OP_ROUND:
4598                case GGML_UNARY_OP_TRUNC:
4599#if defined (GGML_SYCL_F16)
4600                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4601#else
4602                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4603#endif
4604                default:
4605                    return false;
4606            }
4607        case GGML_OP_GLU:
4608            switch (ggml_get_glu_op(op)) {
4609                case GGML_GLU_OP_REGLU:
4610                case GGML_GLU_OP_GEGLU:
4611                case GGML_GLU_OP_SWIGLU:
4612                case GGML_GLU_OP_SWIGLU_OAI:
4613                case GGML_GLU_OP_GEGLU_ERF:
4614                case GGML_GLU_OP_GEGLU_QUICK:
4615                    return ggml_is_contiguous_1(op->src[0]);
4616                default:
4617                    return false;
4618            }
4619            break;
4620        case GGML_OP_MUL_MAT:
4621        case GGML_OP_MUL_MAT_ID:
4622            {
4623                struct ggml_tensor * a = op->src[0];
4624                struct ggml_tensor * b = op->src[1];
4625
4626                if (a->ne[3] != b->ne[3]) {
4627                    return false;
4628                }
4629                ggml_type a_type = a->type;
4630                if (a_type == GGML_TYPE_IQ4_NL  || a_type == GGML_TYPE_IQ4_XS ||
4631                    a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S  ||
4632                    a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
4633                    a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
4634                    ) {
4635                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
4636                        return false;
4637                    }
4638                }
4639                ggml_type src0_type = op->src[0]->type;
4640                if (src0_type == GGML_TYPE_BF16 ) {
4641                    // TODO: support GGML_TYPE_BF16
4642                    // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4643                    return false;
4644                }
4645
4646                // TODO: The configuration below needs more work to be supported with oneDNN
4647                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
4648                    a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
4649                  return false;
4650                }
4651
4652                // TODO: This specific configuration can fail with oneDNN and needs more debugging
4653                if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
4654                    a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
4655                    return false;
4656                }
4657                return true;
4658            }
4659        case GGML_OP_OUT_PROD:
4660            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
4661        case GGML_OP_GET_ROWS:
4662            {
4663                switch (op->src[0]->type) {
4664                    case GGML_TYPE_F16:
4665                    case GGML_TYPE_F32:
4666                    case GGML_TYPE_Q4_0:
4667                    case GGML_TYPE_Q4_1:
4668                    case GGML_TYPE_Q5_0:
4669                    case GGML_TYPE_Q5_1:
4670                    case GGML_TYPE_Q8_0:
4671                        return true;
4672                    default:
4673                        return false;
4674                }
4675            }
        case GGML_OP_SET:
            return (op->type == GGML_TYPE_F32) &&
                   (op->src[0] && op->src[1]) &&
                   (op->src[0]->type == GGML_TYPE_F32) &&
                   (op->src[1]->type == GGML_TYPE_F32);
4681
4682        case GGML_OP_SET_ROWS:
4683            {
4684                return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
4685                         op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
4686                         op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
4687                        (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
4688            }
4689            break;
4690        case GGML_OP_CPY:
4691            {
4692                ggml_type src0_type = op->src[0]->type;
4693                ggml_type src1_type = op->src[1]->type;
4694                if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
4695                    return true;
4696                }
4697                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4698                    return true;
4699                }
4700                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
4701                    return true;
4702                }
4703                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
4704                    return true;
4705                }
4706                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
4707                    return true;
4708                }
4709                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
4710                    return true;
4711                }
4712                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
4713                    return true;
4714                }
4715                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
4716                    return true;
4717                }
4718                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
4719                    return true;
4720                }
4721                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
4722                    return true;
4723                }
4724                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
4725                    return true;
4726                }
4727                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
4728                    return true;
4729                }
4730                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
4731                    return true;
4732                }
4733                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
4734                    return true;
4735                }
4736                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
4737                    return true;
4738                }
4739                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
4740                    return true;
4741                }
                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
                    return true;
                }
4757                return false;
4758            }
4759        case GGML_OP_REPEAT_BACK:
4760            {
4761                ggml_type src0_type = op->src[0]->type;
4762                return src0_type == GGML_TYPE_F32;
4763            }
4764        case GGML_OP_CONCAT:
4765        case GGML_OP_DUP:
4766        case GGML_OP_ARGMAX:
4767        case GGML_OP_NONE:
4768        case GGML_OP_RESHAPE:
4769        case GGML_OP_VIEW:
4770        case GGML_OP_PERMUTE:
4771        case GGML_OP_TRANSPOSE:
4772        case GGML_OP_ADD:
4773        case GGML_OP_ADD1:
4774        case GGML_OP_ADD_ID:
4775        case GGML_OP_SUB:
4776        case GGML_OP_COUNT_EQUAL:
4777        case GGML_OP_MUL:
4778        case GGML_OP_DIV:
4779        case GGML_OP_REPEAT:
4780            return true;
4781        case GGML_OP_PAD_REFLECT_1D:
            return ggml_is_contiguous(op->src[0]) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4783        case GGML_OP_SQR:
4784        case GGML_OP_SQRT:
4785        case GGML_OP_SIN:
4786        case GGML_OP_COS:
4787        case GGML_OP_CLAMP:
4788        case GGML_OP_LOG:
4789#if defined (GGML_SYCL_F16)
            return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == op->src[0]->type));
4791#else
4792            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4793#endif
4794        case GGML_OP_NORM:
4795        case GGML_OP_L2_NORM:
4796        case GGML_OP_GROUP_NORM:
4797        case GGML_OP_RMS_NORM:
4798            return true;
4799        case GGML_OP_RMS_NORM_BACK:
4800            return ggml_is_contiguous(op->src[0]);
4801        case GGML_OP_SCALE:
4802            return true;
4803        case GGML_OP_CONT:
4804            return op->src[0]->type != GGML_TYPE_BF16;
4805        case GGML_OP_TRI:
4806            {
4807                const ggml_tensor * src0 = op->src[0];
4808                return src0 &&
4809                       op->type == GGML_TYPE_F32 &&
4810                       ggml_is_contiguous(src0);
4811            }
4812        case GGML_OP_DIAG_MASK_INF:
4813            return true;
4814        case GGML_OP_SOFT_MAX:
4815            return true;
4816        case GGML_OP_SOFT_MAX_BACK: {
4817            float max_bias = 0.0f;
4818            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4819            return max_bias == 0.0f;
4820        }
4821        case GGML_OP_ROPE:
4822        case GGML_OP_IM2COL:
4823            return true;
4824        case GGML_OP_UPSCALE:
4825            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4826        case GGML_OP_SUM:
4827        case GGML_OP_SUM_ROWS:
4828        case GGML_OP_MEAN:
4829            return ggml_is_contiguous(op->src[0]);
4830        case GGML_OP_ARGSORT:
4831            return op->src[0]->ne[0] * sizeof(int) <=
4832                   ggml_sycl_info().devices[device].smpbo;
4833        case GGML_OP_TOP_K: {
4834            const ggml_tensor * src0 = op->src[0];
4835            const int k = op->ne[0];
4836            return src0 &&
4837                op->type == GGML_TYPE_I32 &&
4838                src0->type == GGML_TYPE_F32 &&
4839                ggml_is_contiguous(src0) &&
4840                k > 0 && k <= 32;
4841        }
4842        case GGML_OP_POOL_2D:
4843        case GGML_OP_ACC:
4844            return true;
4845        case GGML_OP_PAD:
            // TODO: add circular padding support for SYCL, see https://github.com/ggml-org/llama.cpp/pull/16985
4847            if (ggml_get_op_params_i32(op, 8) != 0) {
4848                return false;
4849            }
4850            return ggml_is_contiguous(op->src[0]);
4851        case GGML_OP_LEAKY_RELU:
4852        case GGML_OP_TIMESTEP_EMBEDDING:
4853        case GGML_OP_RWKV_WKV6:
4854        case GGML_OP_RWKV_WKV7:
4855        case GGML_OP_GATED_LINEAR_ATTN:
4856            return true;
4857        case GGML_OP_SSM_CONV:
4858            return op->type == GGML_TYPE_F32 &&
4859                   op->src[0]->type == GGML_TYPE_F32 &&
4860                   op->src[1]->type == GGML_TYPE_F32;
4861        case GGML_OP_ROLL:
4862            return op->type == GGML_TYPE_F32;
4863        case GGML_OP_ARANGE:
4864            return op->type == GGML_TYPE_F32;
4865        default:
4866            return false;
4867    }
4868
4869    GGML_UNUSED(dev);
4870}
4871
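// Only plain SYCL device buffers created for the same device index are accepted here; host and
// split buffer types are rejected.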
4872static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
4873    if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
4874        return false;
4875    }
4876    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
4877    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4878    return buft_ctx->device == sycl_ctx->device;
4879}
4880
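// Effective batch size of an op for the offload heuristic: ne[1] for MUL_MAT, ne[2] for
// MUL_MAT_ID and ROPE, 0 for GET_ROWS (never offloaded on its own), and the total row count
// otherwise.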
4881static int64_t get_op_batch_size(const ggml_tensor * op) {
4882    switch (op->op) {
4883        case GGML_OP_GET_ROWS:
4884            return 0;
4885        case GGML_OP_MUL_MAT:
4886            return op->ne[1];
4887        case GGML_OP_MUL_MAT_ID:
4888        case GGML_OP_ROPE:
4889            return op->ne[2];
4890        default:
4891            return ggml_nrows(op);
4892    }
4893}
4894
4895static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4896    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4897    return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
4898}
4899
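// Backend events wrap a heap-allocated sycl::event. They are unavailable when peer copies are
// disabled at compile time (GGML_SYCL_NO_PEER_COPY), in which case event_new returns nullptr.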
static ggml_backend_event_t ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
4903#ifdef GGML_SYCL_NO_PEER_COPY
4904    return nullptr;
4905#else
4906  sycl::event *event_ptr = new sycl::event();
4907
4908  return new ggml_backend_event{
4909      /* .device = */ dev,
4910      /* .context = */ event_ptr,
4911  };
4912#endif
4913}
4914
4915static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
4916  GGML_UNUSED(dev);
4917  if (event == nullptr) {
4918    return;
4919  }
4920
4921  if (event->context != nullptr) {
4922    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4923    delete sycl_event;
4924    event->context = nullptr;
4925  }
4926
4927  delete event;
4928} catch (sycl::exception const &exc) {
4929  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4930            << ", line:" << __LINE__ << std::endl;
4931  std::exit(1);
4932}
4933
4934
4935static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
4936  GGML_UNUSED(dev);
4937  GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4938
4939  sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4940  SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
4941} catch (sycl::exception const &exc) {
4942  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4943            << ", line:" << __LINE__ << std::endl;
4944  std::exit(1);
4945}
4946
4947static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
4948    /* .get_name                = */ ggml_backend_sycl_device_get_name,
4949    /* .get_description         = */ ggml_backend_sycl_device_get_description,
4950    /* .get_memory              = */ ggml_backend_sycl_device_get_memory,
4951    /* .get_type                = */ ggml_backend_sycl_device_get_type,
4952    /* .get_props               = */ ggml_backend_sycl_device_get_props,
4953    /* .init_backend            = */ ggml_backend_sycl_device_init,
4954    /* .get_buffer_type         = */ ggml_backend_sycl_device_get_buffer_type,
4955    /* .get_host_buffer_type    = */ ggml_backend_sycl_device_get_host_buffer_type,
4956    /* .buffer_from_host_ptr    = */ ggml_backend_sycl_device_buffer_from_host_ptr,
4957    /* .supports_op             = */ ggml_backend_sycl_device_supports_op,
4958    /* .supports_buft           = */ ggml_backend_sycl_device_supports_buft,
4959    /* .offload_op              = */ ggml_backend_sycl_device_offload_op,
4960    /* .event_new               = */ ggml_backend_sycl_device_event_new,
4961    /* .event_free              = */ ggml_backend_sycl_device_event_free,
4962    /* .event_synchronize       = */ ggml_backend_sycl_device_event_synchronize,
4963};
4964
4965// backend reg
4966
4967struct ggml_backend_sycl_reg_context {
4968    std::vector<ggml_backend_dev_t> devices;
4969};
4970
4971static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
4972    GGML_UNUSED(reg);
4973    return GGML_SYCL_NAME;
4974}
4975
4976static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
4977    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
4978    return ctx->devices.size();
4979}
4980
4981static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
4982    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
4983    GGML_ASSERT(index < ctx->devices.size());
4984    return ctx->devices[index];
4985}
4986
4987static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
4988    GGML_UNUSED(reg);
4989
4990    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
4991        return (void *)ggml_backend_sycl_split_buffer_type;
4992    }
4993
4994    // SYCL doesn't support registering host memory, left here for reference
4995    // "ggml_backend_register_host_buffer"
4996    // "ggml_backend_unregister_host_buffer"
4997    GGML_UNUSED(name);
4998    return nullptr;
4999}
5000
5001static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
5002    /* .get_name          = */ ggml_backend_sycl_reg_get_name,
5003    /* .get_device_count  = */ ggml_backend_sycl_reg_get_device_count,
5004    /* .get_device        = */ ggml_backend_sycl_reg_get_device,
5005    /* .get_proc_address  = */ ggml_backend_sycl_reg_get_proc_address,
5006};
5007
5008
5009// backend registry
5010
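// Lazily builds the backend registry on first use (guarded by a mutex): one device entry per
// SYCL device, carrying its name, description, and the minimum batch size used by offload_op.
// The GGML_OP_OFFLOAD_MIN_BATCH environment variable overrides the default threshold of 32.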
5011ggml_backend_reg_t ggml_backend_sycl_reg() {
5012    static ggml_backend_reg reg;
5013    static bool initialized = false;
5014
5015    {
5016        static std::mutex mutex;
5017        std::lock_guard<std::mutex> lock(mutex);
5018        if (!initialized) {
5019            ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
5020            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
5021
5022            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
5023                ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
5024                dev_ctx->device = i;
5025                dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
5026
5027                ggml_sycl_set_device(i);
5028
5029                dpct::device_info prop;
5030                SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
5031                    prop, dpct::dev_mgr::instance().get_device(i))));
5032
5033                dev_ctx->description = prop.get_name();
5034                dev_ctx->op_offload_min_batch_size = min_batch_size;
5035
5036                ggml_backend_dev_t dev = new ggml_backend_device {
5037                    /* .iface       = */ ggml_backend_sycl_device_interface,
5038                    /* .reg         = */ &reg,
5039                    /* .context     = */ dev_ctx
5040                };
5041                ctx->devices.push_back(dev);
5042            }
5043
5044            reg = ggml_backend_reg {
5045                /* .api_version = */ GGML_BACKEND_API_VERSION,
5046                /* .iface       = */ ggml_backend_sycl_reg_interface,
5047                /* .context     = */ ctx
5048            };
5049        }
5050
5051        initialized = true;
5052    }
5053
5054    return &reg;
5055}
5056
5057ggml_backend_t ggml_backend_sycl_init(int device) {
5058    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
5059    ggml_check_sycl();
5060
5061    check_allow_gpu_index(device);
5062
5063    ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
5064    if (ctx == nullptr) {
5065        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
5066        return nullptr;
    }
5068
5069    ggml_backend_t sycl_backend = new ggml_backend {
5070        /* .guid    = */ ggml_backend_sycl_guid(),
5071        /* .iface   = */ ggml_backend_sycl_interface,
5072        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
5073        /* .context = */ ctx
5074    };
5075
5076    return sycl_backend;
5077}
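
// Minimal usage sketch (illustrative only, relying on the generic ggml-backend API rather than
// anything defined in this file):
//
//     ggml_backend_t backend = ggml_backend_sycl_init(0 /* device index */);
//     if (backend != nullptr) {
//         // ... allocate buffers and compute graphs through the ggml-backend API ...
//         ggml_backend_free(backend);
//     }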
5078
5079GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg)