1//
2// MIT license
3// Copyright (C) 2024 Intel Corporation
4// SPDX-License-Identifier: MIT
5//
6
7//
8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9// See https://llvm.org/LICENSE.txt for license information.
10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11//
12
13#include <algorithm>
14#include <assert.h>
15#include <atomic>
16#include <cinttypes>
17#include <cstddef>
18#include <cstdint>
19#include <cstdlib>
20#include <float.h>
21#include <limits>
22#include <stdint.h>
23#include <stdio.h>
24#include <vector>
25#include <cmath>
26#include <iostream>
27#include <fstream>
29#include <stdlib.h>
30#include <regex>
31
32#include <sycl/sycl.hpp>
33#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
34# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
35#endif
36#include <sycl/half_type.hpp>
37
38#include "ggml-sycl.h"
39#include "ggml-impl.h"
40#include "ggml-backend-impl.h"
41
42#include "ggml-sycl/add-id.hpp"
43#include "ggml-sycl/backend.hpp"
44#include "ggml-sycl/common.hpp"
45#include "ggml-sycl/element_wise.hpp"
46#include "ggml-sycl/norm.hpp"
47#include "ggml-sycl/presets.hpp"
48#include "ggml-sycl/gemm.hpp"
49#include "ggml-sycl/set_rows.hpp"
50#include "ggml-sycl/set.hpp"
51#include "ggml-sycl/sycl_hw.hpp"
52#include "ggml-sycl/getrows.hpp"
53#include "ggml-sycl/repeat_back.hpp"
54#include "ggml-sycl/quantize.hpp"
55#include "ggml-sycl/ssm_conv.hpp"
56#include "ggml.h"
57
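// Runtime configuration flags. Most are populated once from the corresponding GGML_SYCL_*
// environment variables in ggml_check_sycl(); g_ggml_sycl_use_async_mem_op is derived from the
// graph setting and the devices' async-allocation capability.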
58static bool g_sycl_loaded = false;
59int g_ggml_sycl_debug = 0;
60int g_ggml_sycl_disable_optimize = 0;
61int g_ggml_sycl_disable_graph = 0;
62int g_ggml_sycl_disable_dnn = 0;
63int g_ggml_sycl_prioritize_dmmv = 0;
64int g_ggml_sycl_use_async_mem_op = 0;
65
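// Enumerate all SYCL devices once: record per-device compute capability, compute-unit count,
// local-memory size and the reorder feature flag, and build default_tensor_split as cumulative
// fractions of total global memory across devices.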
66static ggml_sycl_device_info ggml_sycl_init() {
67 ggml_sycl_device_info info = {};
68
69 info.device_count = dpct::dev_mgr::instance().device_count();
70 if (info.device_count == 0) {
        GGML_LOG_ERROR("%s: failed to initialize %s: no SYCL devices found\n", __func__, GGML_SYCL_NAME);
72 return info;
73 }
74
75 GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
76
77 int64_t total_vram = 0;
/* XMX status reporting is disabled for now; kept here for a future XMX optimization pass. */
79// #if defined(SYCL_USE_XMX)
80// GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
81// #else
82// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
83// #endif
84 for (int i = 0; i < info.device_count; ++i) {
85 info.devices[i].vmm = 0;
86 dpct::device_info prop;
87 sycl::device device = dpct::dev_mgr::instance().get_device(i);
88
89 SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
90 prop, device)));
91
92 info.default_tensor_split[i] = total_vram;
93 total_vram += prop.get_global_mem_size();
94
95 info.devices[i].cc =
96 100 * prop.get_major_version() + 10 * prop.get_minor_version();
97 info.devices[i].nsm = prop.get_max_compute_units();
98 info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
99 info.devices[i].smpbo = prop.get_local_mem_size();
100
101 info.max_work_group_sizes[i] = prop.get_max_work_group_size();
102 }
103
104 for (int id = 0; id < info.device_count; ++id) {
105 info.default_tensor_split[id] /= total_vram;
106 }
107 return info;
108}
109
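// Lazily-initialized singleton: the device query in ggml_sycl_init() runs only once, on the first
// call, thanks to the function-local static.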
110const ggml_sycl_device_info & ggml_sycl_info() {
111 static ggml_sycl_device_info info = ggml_sycl_init();
112 return info;
113}
114
115static void print_device_detail(int id, sycl::device &device, std::string device_type) {
116
117 dpct::device_info prop;
118 SYCL_CHECK(CHECK_TRY_ERROR(
119 dpct::get_device_info(prop, device)));
120
121 std::string version;
122 version += std::to_string(prop.get_major_version());
123 version += ".";
124 version += std::to_string(prop.get_minor_version());
125
126 device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
127 std::string name = std::string(prop.get_name());
128 name = std::regex_replace(name, std::regex("\\(R\\)"), "");
129 name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
130
131 auto global_mem_size = prop.get_global_mem_size()/1000000;
132 GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
133 name.c_str(), version.c_str(), prop.get_max_compute_units(),
134 prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
135 global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
136}
137
138static void print_device_opt_feature(int device_count) {
139 GGML_LOG_INFO("SYCL Optimization Feature:\n");
140 GGML_LOG_INFO(
141 "|ID| Device Type|Reorder|\n");
142 GGML_LOG_INFO(
143 "|--|-------------------|-------|\n");
144 std::map<std::string, size_t> DeviceNums;
145 for (int id = 0; id < device_count; ++id) {
146 sycl::device device = dpct::dev_mgr::instance().get_device(id);
147 std::string backend_type = get_device_backend_and_type(device);
148 int type_id = DeviceNums[backend_type]++;
149 std::stringstream device_type;
150 device_type << "[" << backend_type << ":" << std::to_string(type_id)
151 << "]";
152 std::string device_type_s = device_type.str();
153 device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
154 GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
155 ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
156 }
157
158}
159void ggml_backend_sycl_print_sycl_devices() {
160 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
161 int device_count = dpct::dev_mgr::instance().device_count();
162 std::map<std::string, size_t> DeviceNums;
163 GGML_LOG_INFO("Found %d SYCL devices:\n", device_count);
164
165 GGML_LOG_INFO(
166 "| | | | "
167 " |Max | |Max |Global | |\n");
168 GGML_LOG_INFO(
169 "| | | | "
170 " |compute|Max work|sub |mem | |\n");
171 GGML_LOG_INFO(
172 "|ID| Device Type| "
173 "Name|Version|units |group |group|size | Driver version|\n");
174 GGML_LOG_INFO(
175 "|--|-------------------|---------------------------------------|------"
176 "-|-------|--------|-----|-------|---------------------|\n");
177
178 for (int id = 0; id < device_count; ++id) {
179 sycl::device device = dpct::dev_mgr::instance().get_device(id);
180 std::string backend_type = get_device_backend_and_type(device);
181 int type_id = DeviceNums[backend_type]++;
182 std::stringstream device_type;
183 device_type << "[" << backend_type << ":" << std::to_string(type_id)
184 << "]";
185 print_device_detail(id, device, device_type.str());
186 }
187
188 print_device_opt_feature(device_count);
189}
190
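// Parse a non-negative integer from an environment variable, falling back to default_val when the
// variable is unset or not a number. Illustrative example (not from the code):
//   GGML_SYCL_DEBUG=1   -> returns 1
//   GGML_SYCL_DEBUG=abc -> returns default_val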
191static inline int get_sycl_env(const char *env_name, int default_val) {
192 char *user_device_string = getenv(env_name);
193 int user_number = default_val;
194
195 unsigned n;
196 if (user_device_string != NULL &&
197 sscanf(user_device_string, " %u", &n) == 1) {
198 user_number = (int)n;
199 } else {
200 user_number = default_val;
201 }
202 return user_number;
203}
204
205static void ggml_check_sycl() try {
206 static bool initialized = false;
207
208 if (!initialized) {
209 g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
210 g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
211 g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
212 g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
213 g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
214 GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
215 GGML_LOG_INFO("Running with Environment Variables:\n");
216 GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
217 GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
218#ifdef GGML_SYCL_GRAPH
219 GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
220#else
221 GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
222#endif
223#if GGML_SYCL_DNNL
224 GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
225#else
226 GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
227#endif
228 GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
229 GGML_LOG_INFO("Build with Macros:\n");
230#if defined(GGML_SYCL_FORCE_MMQ)
231 GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
232#else
233 GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
234#endif
235#if defined(GGML_SYCL_F16)
236 GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
237#else
238 GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
239#endif
240
/* DO NOT REMOVE: kept for a future XMX optimization pass.
242#if defined(SYCL_USE_XMX)
243 fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
244#else
245 fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
246#endif
247*/
248 // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
249 // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
250 // other places.
251#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
252 g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
253 if (g_ggml_sycl_use_async_mem_op) {
254 for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
255 if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
256 g_ggml_sycl_use_async_mem_op = 0;
257 break;
258 }
259 }
260 }
261#endif
262 if (CHECK_TRY_ERROR(g_all_sycl_device_count =
263 dpct::dev_mgr::instance().device_count()) != 0) {
264 initialized = true;
265 g_sycl_loaded = false;
266 return;
267 }
268 GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
269
270 initialized = true;
271 g_sycl_loaded = true;
272 ggml_backend_sycl_print_sycl_devices();
273 }
274}
275catch (sycl::exception const &exc) {
276 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
277 << ", line:" << __LINE__ << std::endl;
278 std::exit(1);
279}
280
281/*
device_index: device index from 0 to n-1 (contiguous numbering).
283 It is used for device select/set in SYCL backend internal data structure.
284*/
285inline void check_allow_gpu_index(const int device_index) {
286 if (device_index >= ggml_sycl_info().device_count) {
287 char error_buf[256];
288 snprintf(
289 error_buf,
290 sizeof(error_buf),
291 "%s error: device_index:%d is out of range: [0-%d]",
292 __func__,
293 device_index,
294 ggml_sycl_info().device_count - 1);
295 GGML_LOG_ERROR("%s\n", error_buf);
296 assert(false);
297 }
298}
299
300GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
301 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
    for (int i = 0; i < max_len; i++) id_list[i] = -1;

    for (int i = 0; i < ggml_sycl_info().device_count; i++) {
        if (i >= max_len) break;
306 id_list[i] = i;
307 }
308 return;
309}
310catch (sycl::exception const &exc) {
311 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
312 << ", line:" << __LINE__ << std::endl;
313 std::exit(1);
314}
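
// Illustrative (hypothetical) caller of ggml_backend_sycl_get_gpu_list: unused slots stay -1, so
// iteration can stop at the first negative id.
//
//   int ids[GGML_SYCL_MAX_DEVICES];
//   ggml_backend_sycl_get_gpu_list(ids, GGML_SYCL_MAX_DEVICES);
//   for (int i = 0; i < GGML_SYCL_MAX_DEVICES && ids[i] >= 0; ++i) {
//       // use ids[i]
//   }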
315
316// sycl buffer
317
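// Per-buffer context: holds the device allocation (dev_ptr), the stream used to free it, and any
// per-tensor extras created in init_tensor; everything is released in the destructor.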
318struct ggml_backend_sycl_buffer_context {
319 int device;
320 void * dev_ptr = nullptr;
321 queue_ptr stream;
322 std::string name;
323 optimize_feature opt_feature;
324 std::vector<ggml_tensor_extra_gpu *> tensor_extras;
325
326 ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
327 device(device), dev_ptr(dev_ptr), stream(stream) {
328 check_allow_gpu_index(device);
329 name = (GGML_SYCL_NAME + std::to_string(device));
330 opt_feature = ggml_sycl_info().devices[device].opt_feature;
331 }
332
333 ~ggml_backend_sycl_buffer_context() {
334 if (dev_ptr != nullptr) {
335 ggml_sycl_set_device(device);
336 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
337 }
338
339 //release extra used by tensors
340 for (ggml_tensor_extra_gpu * extra : tensor_extras) {
341 release_extra_gpu(extra);
342 }
343
344 }
345};
346
347static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);
348
349static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
350 return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;
351}
352
353static void
354ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
355 ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
356 ggml_sycl_set_device(ctx->device);
357
358 delete ctx;
359}
360catch (sycl::exception const &exc) {
361 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
362 << ", line:" << __LINE__ << std::endl;
363 std::exit(1);
364}
365
366static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
367 ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
368 return ctx->dev_ptr;
369}
370
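// init_tensor: views share their source's allocation and return early; for selected quantized types
// (Q4_0, Q4_K, Q6_K) a ggml_tensor_extra_gpu is attached to support the reorder optimization (unless
// disabled), and the padding region of quantized tensors is zeroed to avoid NaNs from uninitialized
// memory.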
371static enum ggml_status
372ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
373 ggml_tensor *tensor) try {
374 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
375 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
376 ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
377
378 if (tensor->view_src != NULL) {
379 assert(tensor->view_src->buffer->buft == buffer->buft);
380 return GGML_STATUS_SUCCESS;
381 }
382 if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
383 !g_ggml_sycl_disable_optimize) {
384 ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
385 tensor->extra = extra;
        ctx->tensor_extras.push_back(extra);  // tracked so it can be released when the context is destroyed.
387 }
388
389 if (ggml_is_quantized(tensor->type)) {
390 // initialize padding to 0 to avoid possible NaN values
391 size_t original_size = ggml_nbytes(tensor);
392 size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
393
394 if (padded_size > original_size && tensor->view_src == nullptr) {
395 SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
396 (char *)tensor->data + original_size, 0,
397 padded_size - original_size).wait()));
398 }
399 }
400 return GGML_STATUS_SUCCESS;
401}
402catch (sycl::exception const &exc) {
403 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
404 << ", line:" << __LINE__ << std::endl;
405 std::exit(1);
406}
407
408static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
409 ggml_tensor *tensor,
410 const void *data, size_t offset,
411 size_t size) try {
412 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
413 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
414 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
415 ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
416 ggml_sycl_set_device(ctx->device);
417 auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
418 SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
419#ifndef _WIN32
    // Note: stage the data in a temporary host buffer before copying it to the device; this works around an mmap() issue on PVC GPUs.
    // This function is called while loading the model from disk. Replacing the per-call allocation with a persistent buffer would not save time and would add a memory-leak risk.
    char * host_buf = (char *) malloc(size);
    GGML_ASSERT(host_buf != nullptr);
    memcpy(host_buf, data, size);
424 SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
425 free(host_buf);
426#else
427 SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
428#endif
429}
430catch (sycl::exception const &exc) {
431 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
432 << ", line:" << __LINE__ << std::endl;
433 std::exit(1);
434}
435
436static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
437 const ggml_tensor *tensor,
438 void *data, size_t offset,
439 size_t size) try {
440 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
441 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
442 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
443 ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
444
445 ggml_sycl_set_device(ctx->device);
446 auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
447
448 SYCL_CHECK(CHECK_TRY_ERROR(
449 stream.memcpy(data, (const char *)tensor->data + offset, size)
450 .wait()));
451}
452catch (sycl::exception const &exc) {
453 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
454 << ", line:" << __LINE__ << std::endl;
455 std::exit(1);
456}
457
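// Device-to-device copy staged through a temporary host buffer; used because direct cross-GPU
// copies are currently broken (see the TODO in ggml_backend_sycl_buffer_cpy_tensor below).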
458static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
459 const void *ptr_src, size_t size) {
460 char *host_buf = (char *)malloc(size);
461 q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
462 q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
463 free(host_buf);
464}
465
466static bool
467ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
468 const ggml_tensor *src,
469 ggml_tensor *dst) try {
470 bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
471 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
472 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
473 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
474 GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
475 if (is_cpy_supported) {
476 ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
477 ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
478
479 ggml_sycl_set_device(src_ctx->device);
480 /*
481 DPCT1009:198: SYCL uses exceptions to report errors and does not use the
482 error codes. The original code was commented out and a warning string
483 was inserted. You need to rewrite this code.
484 */
485 SYCL_CHECK(CHECK_TRY_ERROR(
486 dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
487 ggml_sycl_set_device(dst_ctx->device);
488 /*
489 DPCT1009:199: SYCL uses exceptions to report errors and does not use the
490 error codes. The original code was commented out and a warning string
491 was inserted. You need to rewrite this code.
492 */
493 SYCL_CHECK(CHECK_TRY_ERROR(
494 dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
495 /*
496 DPCT1009:200: SYCL uses exceptions to report errors and does not use the
497 error codes. The original code was commented out and a warning string
498 was inserted. You need to rewrite this code.
499 */
500
501 queue_ptr stream_dst = dst_ctx->stream;
502 queue_ptr stream_src = src_ctx->stream;
503 size_t size = ggml_nbytes(src);
504
        // TODO: temporary workaround for a known issue: device-to-device copies across GPUs fail.
506 dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
507
// TODO: known issue: device-to-device copy across GPUs fails; re-enable this code path once it is fixed. Do not remove.
509#if 0
510 SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
511 (char *)dst->data, (const char *)src->data, size).wait()));
512
513 /*
514 DPCT1009:201: SYCL uses exceptions to report errors and does not use the
515 error codes. The original code was commented out and a warning string
516 was inserted. You need to rewrite this code.
517 */
518 SYCL_CHECK(CHECK_TRY_ERROR(
519 dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
520#endif
521 return true;
522 }
523 return false;
524 GGML_UNUSED(buffer);
525} catch (const sycl::exception & exc) {
526 std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
527 std::exit(1);
528}
529
530static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
531 uint8_t value) try {
532 GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
533 ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
534
535 ggml_sycl_set_device(ctx->device);
536 queue_ptr stream = ctx->stream;
537 SYCL_CHECK(
538 CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
539
540 SYCL_CHECK(CHECK_TRY_ERROR((*stream)
541 .memset(ctx->dev_ptr, value, buffer->size)
542 .wait()));
543}
544catch (sycl::exception const &exc) {
545 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
546 << ", line:" << __LINE__ << std::endl;
547 std::exit(1);
548}
549
550static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
551 size_t offset, size_t size) {
552 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
553 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
554 GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
555 ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
556 SYCL_CHECK(ggml_sycl_set_device(ctx->device));
557 auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
558 if (size == 0) {
559 return; // Nothing to do
560 }
561 if (tensor->data == nullptr) {
562 GGML_ABORT("Error: Tensor data pointer is null.\n");
563 }
564 void * target_ptr = static_cast<char *>(tensor->data) + offset;
565 SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
566 SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
567}
568
569static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
570 GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
571 if (buffer == nullptr) {
572 return;
573 }
574
575 ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
576
577 if (ctx != nullptr) {
578 for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
579 release_extra_gpu(extra);
580 }
581 ctx->tensor_extras.clear(); // reset the tensor_extras vector
582 }
583}
584
585static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
586 /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
587 /* .get_base = */ ggml_backend_sycl_buffer_get_base,
588 /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
589 /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
590 /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
591 /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
592 /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
593 /* .clear = */ ggml_backend_sycl_buffer_clear,
594 /* .reset = */ ggml_backend_sycl_buffer_reset,
595};
596
597// sycl buffer type
598struct ggml_backend_sycl_buffer_type_context {
599 int device;
600 std::string name;
601
602 // each buffer type has its own stream
603 queue_ptr stream = nullptr;
604};
605
606static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
607 ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
608
609 return ctx->name.c_str();
610}
611
612static ggml_backend_buffer_t
613ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
614 size_t size) try {
615 ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
616 ggml_sycl_set_device(buft_ctx->device);
617 const queue_ptr stream = buft_ctx->stream;
    size = std::max(size, (size_t)1); // sycl::malloc_device returns nullptr for size 0
619
620 void * dev_ptr;
621 SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
622 size, *stream)));
623 if (!dev_ptr) {
        GGML_LOG_ERROR("%s: can't allocate %zu bytes of memory on device\n", __func__, size);
625 return nullptr;
626 }
627 ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
628 return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
629}
630catch (sycl::exception const &exc) {
631 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
632 << ", line:" << __LINE__ << std::endl;
633 std::exit(1);
634}
635
636static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
637 return 128;
638 GGML_UNUSED(buft);
639}
640
641static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
642 return dpct::get_current_device().get_max_mem_alloc_size();
643
644 GGML_UNUSED(buft);
645}
646
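// For quantized tensors the allocation is padded so the last row covers a full multiple of
// MATRIX_ROW_PADDING elements, avoiding out-of-bounds reads in the mat-mul kernels; init_tensor
// zero-fills that padding.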
647static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
648 size_t size = ggml_nbytes(tensor);
649 int64_t ne0 = tensor->ne[0];
650
651 if (ggml_is_quantized(tensor->type)) {
652 if (ne0 % MATRIX_ROW_PADDING != 0) {
653 size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
654 }
655 }
656
657 return size;
658
659 GGML_UNUSED(buft);
660}
661
662static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
663 /* .get_name = */ ggml_backend_sycl_buffer_type_get_name,
664 /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer,
665 /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment,
666 /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size,
667 /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size,
668 /* .is_host = */ NULL,
669};
670
671ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
672 static std::mutex mutex;
673 std::lock_guard<std::mutex> lock(mutex);
674
675
676 auto dev_count = ggml_backend_sycl_get_device_count();
677
    if (device >= dev_count || device < 0) {
        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
            device, dev_count - 1);
        GGML_ASSERT(device < dev_count);
682 }
683 static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
684
685 static bool ggml_backend_sycl_buffer_type_initialized = false;
686
687 if (!ggml_backend_sycl_buffer_type_initialized) {
688 for (int i = 0; i < dev_count; i++) {
689 auto & device_i = dpct::dev_mgr::instance().get_device(i);
690 queue_ptr stream = &(device_i.default_queue());
691 ggml_backend_sycl_buffer_types[i] = {
692 /* .iface = */ ggml_backend_sycl_buffer_type_interface,
693 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
694 /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
695 };
696 }
697 ggml_backend_sycl_buffer_type_initialized = true;
698 }
699 return &ggml_backend_sycl_buffer_types[device];
700}
701
702static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
703 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
704
705 int device = ctx->device;
    if (device >= ggml_sycl_info().device_count || device < 0) {
        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
            device, ggml_sycl_info().device_count - 1);
        GGML_ASSERT(device < ggml_sycl_info().device_count);
710 }
711 static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
712
713 static bool ggml_backend_sycl_buffer_type_initialized = false;
714
715 if (!ggml_backend_sycl_buffer_type_initialized) {
716 for (int i = 0; i < ggml_sycl_info().device_count; i++) {
717 ggml_backend_sycl_buffer_types[i] = {
718 /* .iface = */ ggml_backend_sycl_buffer_type_interface,
719 /* .device = */ nullptr,
720 /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
721 };
722 }
723 ggml_backend_sycl_buffer_type_initialized = true;
724 }
725 return &ggml_backend_sycl_buffer_types[device];
726}
727
728// sycl split buffer
729
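// Row-count granularity used when splitting a tensor across devices: rows are assigned in multiples
// of this value, chosen per tensor type and based on the highest compute capability among the
// devices that actually receive a share of the tensor.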
730static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
731 int64_t min_compute_capability = INT_MAX;
732 int64_t max_compute_capability = INT_MIN;
733 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
734 if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
735 if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
736 min_compute_capability = ggml_sycl_info().devices[i].cc;
737 }
738 if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
739 max_compute_capability = ggml_sycl_info().devices[i].cc;
740 }
741 }
742 }
743
744 switch(type) {
745 case GGML_TYPE_Q4_0:
746 case GGML_TYPE_Q4_1:
747 return max_compute_capability >= VER_GEN9 ? 128 : 64;
748 case GGML_TYPE_Q5_0:
749 case GGML_TYPE_Q5_1:
750 case GGML_TYPE_Q8_0:
751 return 64;
752 case GGML_TYPE_F16:
753 case GGML_TYPE_F32:
754 return 1;
755 case GGML_TYPE_Q2_K:
756 case GGML_TYPE_Q3_K:
757 case GGML_TYPE_Q4_K:
758 case GGML_TYPE_Q5_K:
759 case GGML_TYPE_IQ2_XXS:
760 case GGML_TYPE_IQ2_XS:
761 case GGML_TYPE_IQ2_S:
762 case GGML_TYPE_IQ1_S:
763 case GGML_TYPE_IQ1_M:
764 case GGML_TYPE_IQ3_XXS:
765 case GGML_TYPE_IQ4_XS:
766 case GGML_TYPE_IQ4_NL:
767 return max_compute_capability >= VER_GEN9 ? 128 : 64;
768 case GGML_TYPE_IQ3_S:
769 return max_compute_capability >= VER_GEN9 ? 128 : 64;
770 case GGML_TYPE_Q6_K:
771 return 64;
772 default:
773 GGML_ABORT("fatal error");
774 }
775}
776
777static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
778 const int64_t nrows = ggml_nrows(tensor);
779 const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
780
781 *row_low = id == 0 ? 0 : nrows*tensor_split[id];
782 *row_low -= *row_low % rounding;
783 if (id == ggml_sycl_info().device_count - 1) {
784 *row_high = nrows;
785 } else {
786 *row_high = nrows*tensor_split[id + 1];
787 *row_high -= *row_high % rounding;
788 }
789}
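
// Illustrative example (values assumed, not taken from the code): with 2 devices,
// tensor_split = {0.0, 0.5}, nrows = 1000 and rounding = 128:
//   device 0: row_low = 0,   row_high = 500 - (500 % 128) = 384
//   device 1: row_low = 384, row_high = 1000 (the last device always takes the remaining rows)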
790
791static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
792 static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
793
794 return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
795}
796
797struct ggml_backend_sycl_split_buffer_type_context {
798 std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
799};
800
801struct ggml_backend_sycl_split_buffer_context {
802 ~ggml_backend_sycl_split_buffer_context() try {
803 for (ggml_tensor_extra_gpu * extra : tensor_extras) {
804 release_extra_gpu(extra, streams);
805 }
806 }
807 catch (sycl::exception const &exc) {
808 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
809 << ", line:" << __LINE__ << std::endl;
810 std::exit(1);
811 }
812
813 std::vector<ggml_tensor_extra_gpu *> tensor_extras;
814 std::vector<queue_ptr> streams;
815};
816
817static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
818 ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
819 delete ctx;
820}
821
822static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
823 // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
824 return (void *)0x1000;
825
826 GGML_UNUSED(buffer);
827}
828
829static enum ggml_status
830ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
831 ggml_tensor *tensor) try {
832 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
833 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
834 GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
835
836 ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
837 ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
838
839 const int64_t ne0 = tensor->ne[0];
840
841 ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
842
843 ctx->tensor_extras.push_back(extra);
844 ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
845
846 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
847 int64_t row_low, row_high;
848 get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
849
850 int64_t nrows_split = row_high - row_low;
851 if (nrows_split == 0) {
852 continue;
853 }
854
855 size_t size = ggml_nbytes_split(tensor, nrows_split);
856 const size_t original_size = size;
857
858 // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
859 if (ne0 % MATRIX_ROW_PADDING != 0) {
860 size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
861 }
862
863 // FIXME: do not crash if SYCL Buffer alloc fails
864 // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
865 ggml_sycl_set_device(i);
866 const queue_ptr stream = ctx->streams[i];
867 char * buf;
868 /*
869 DPCT1009:208: SYCL uses exceptions to report errors and does not use the
870 error codes. The original code was commented out and a warning string
871 was inserted. You need to rewrite this code.
872 */
873 SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
874 size, *stream)));
875 if (!buf) {
876 char err_buf[1024];
            snprintf(err_buf, sizeof(err_buf), "%s: can't allocate %zu bytes of memory on device\n", __func__, size);
878 throw std::runtime_error(err_buf);
879 }
880 // set padding to 0 to avoid possible NaN values
881 if (size > original_size) {
882 /*
883 DPCT1009:209: SYCL uses exceptions to report errors and does not use
884 the error codes. The original code was commented out and a warning
885 string was inserted. You need to rewrite this code.
886 */
887 SYCL_CHECK(CHECK_TRY_ERROR(
888 (*stream)
889 .memset(buf + original_size, 0, size - original_size)
890 .wait()));
891 }
892
893 extra->data_device[i] = buf;
894
895 for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
896 /*
897 DPCT1009:210: SYCL uses exceptions to report errors and does not use
898 the error codes. The original code was commented out and a warning
899 string was inserted. You need to rewrite this code.
900 */
901 SYCL_CHECK(
902 CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
903 }
904 }
905 tensor->extra = extra;
906 return GGML_STATUS_SUCCESS;
907}
908catch (sycl::exception const &exc) {
909 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
910 << ", line:" << __LINE__ << std::endl;
911 std::exit(1);
912}
913
914static void
915ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
916 ggml_tensor *tensor, const void *data,
917 size_t offset, size_t size) try {
918 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
919 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
920 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
921 // split tensors must always be set in their entirety at once
922 GGML_ASSERT(offset == 0);
923 GGML_ASSERT(size == ggml_nbytes(tensor));
924
925 ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
926 ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
927
928 const int64_t ne0 = tensor->ne[0];
929 const size_t nb1 = tensor->nb[1];
930 ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
931
932 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
933 int64_t row_low, row_high;
934 get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
935
936 int64_t nrows_split = row_high - row_low;
937 if (nrows_split == 0) {
938 continue;
939 }
940
941 const size_t offset_split = row_low*nb1;
942 size_t size = ggml_nbytes_split(tensor, nrows_split);
943 const size_t original_size = size;
944
945 // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
946 if (ne0 % MATRIX_ROW_PADDING != 0) {
947 size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
948 }
949
950 const char * buf_host = (const char *)data + offset_split;
951 /*
952 DPCT1009:211: SYCL uses exceptions to report errors and does not use the
953 error codes. The original code was commented out and a warning string
954 was inserted. You need to rewrite this code.
955 */
956 ggml_sycl_set_device(i);
957 const queue_ptr stream = ctx->streams[i];
958 SYCL_CHECK(CHECK_TRY_ERROR(
959 (*stream)
960 .memcpy(extra->data_device[i], buf_host, original_size)
961 .wait()));
962 }
963}
964catch (sycl::exception const &exc) {
965 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
966 << ", line:" << __LINE__ << std::endl;
967 std::exit(1);
968}
969
970static void
971ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
972 const ggml_tensor *tensor, void *data,
973 size_t offset, size_t size) try {
974 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
975 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
976 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    // split tensors must always be read in their entirety at once
978 GGML_ASSERT(offset == 0);
979 GGML_ASSERT(size == ggml_nbytes(tensor));
980
981 ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
982 ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
983
984 const int64_t ne0 = tensor->ne[0];
985 const size_t nb1 = tensor->nb[1];
986 ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
987
988 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
989 int64_t row_low, row_high;
990 get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
991
992 int64_t nrows_split = row_high - row_low;
993 if (nrows_split == 0) {
994 continue;
995 }
996
997 const size_t offset_split = row_low*nb1;
998 size_t size = ggml_nbytes_split(tensor, nrows_split);
999 const size_t original_size = size;
1000
1001 // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
1002 if (ne0 % MATRIX_ROW_PADDING != 0) {
1003 size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1004 }
1005
1006 char * buf_host = (char *)data + offset_split;
1007 /*
1008 DPCT1009:212: SYCL uses exceptions to report errors and does not use the
1009 error codes. The original code was commented out and a warning string
1010 was inserted. You need to rewrite this code.
1011 */
1012 ggml_sycl_set_device(i);
1013 const queue_ptr stream = ctx->streams[i];
1014 SYCL_CHECK(CHECK_TRY_ERROR(
1015 (*stream)
1016 .memcpy(buf_host, extra->data_device[i], original_size)
1017 .wait()));
1018 }
1019}
1020catch (sycl::exception const &exc) {
1021 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1022 << ", line:" << __LINE__ << std::endl;
1023 std::exit(1);
1024}
1025
1026static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1027 GGML_UNUSED(buffer);
1028 GGML_UNUSED(value);
1029}
1030
1031static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
1032 /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer,
1033 /* .get_base = */ ggml_backend_sycl_split_buffer_get_base,
1034 /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor,
1035 /* .memset_tensor = */ NULL,
1036 /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor,
1037 /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor,
1038 /* .cpy_tensor = */ NULL,
1039 /* .clear = */ ggml_backend_sycl_split_buffer_clear,
1040 /* .reset = */ NULL,
1041};
1042
1043// sycl split buffer type
1044
1045static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1046 return GGML_SYCL_NAME "_Split";
1047
1048 GGML_UNUSED(buft);
1049}
1050
1051static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
1052 return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;
1053}
1054
1055static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1056 // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
1057 // instead, we allocate them for each tensor separately in init_tensor
1058 // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
1059 // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
1060 ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
1061
1062 return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
1063}
1064
1065static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1066 return 128;
1067 GGML_UNUSED(buft);
1068}
1069
1070static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
1071 ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
1072
1073 size_t total_size = 0;
1074
1075 const int64_t ne0 = tensor->ne[0];
1076
1077 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
1078 int64_t row_low, row_high;
1079 get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
1080
1081 int64_t nrows_split = row_high - row_low;
1082 if (nrows_split == 0) {
1083 continue;
1084 }
1085
1086 total_size += ggml_nbytes_split(tensor, nrows_split);
1087
1088 // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
1089 if (ne0 % MATRIX_ROW_PADDING != 0) {
1090 total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1091 }
1092 }
1093
1094 return total_size;
1095}
1096
1097static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1098 return false;
1099
1100 GGML_UNUSED(buft);
1101}
1102
1103static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
1104 /* .get_name = */ ggml_backend_sycl_split_buffer_type_get_name,
1105 /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
1106 /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment,
1107 /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1108 /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
1109 /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host,
1110};
1111
1112ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
1113 static std::mutex mutex;
1114 std::lock_guard<std::mutex> lock(mutex);
1115
1116 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
1117 ggml_check_sycl();
1118 // FIXME: this is not thread safe
1119 static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
1120
1121 std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
1122
1123 bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
1124 if (all_zero) {
1125 tensor_split_arr = ggml_sycl_info().default_tensor_split;
1126 } else {
1127 float split_sum = 0.0f;
1128 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
1129 tensor_split_arr[i] = split_sum;
1130 split_sum += tensor_split[i];
1131 }
1132 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
1133 tensor_split_arr[i] /= split_sum;
1134 }
1135 }
1136
1137 auto it = buft_map.find(tensor_split_arr);
1138 if (it != buft_map.end()) {
1139 return &it->second;
1140 }
1141
1142 struct ggml_backend_buffer_type buft {
1143 /* .iface = */ ggml_backend_sycl_split_buffer_type_interface,
1144 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
1145 /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
1146 };
1147
1148 auto result = buft_map.emplace(tensor_split_arr, buft);
1149 return &result.first->second;
1150}
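
// Illustrative example (assumed values): passing tensor_split = {1.0f, 3.0f} on two devices yields
// cumulative boundaries {0.0, 0.25}, i.e. device 0 receives the first 25% of the rows and device 1
// the remaining 75%; a null or all-zero tensor_split falls back to the VRAM-proportional default.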
1151
1152// host buffer type
1153
1154static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1155 return GGML_SYCL_NAME "_Host";
1156
1157 GGML_UNUSED(buft);
1158}
1159
1160inline void * aligned_malloc_host(size_t alignment, size_t size) {
1161#ifdef _WIN32
1162 return _aligned_malloc(size, alignment);
1163#else
1164 return aligned_alloc(alignment, size);
1165#endif
1166}
1167
1168inline void free_aligned_mem_host(void * memblock) {
1169#ifdef _WIN32
1170 _aligned_free(memblock);
1171#else
1172 free(memblock);
1173#endif
1174}
1175
1176static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1177 free_aligned_mem_host((void *)buffer->context);
1178}
1179
1180static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1181 void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
1182 if (ptr == nullptr) {
1183 // fallback to cpu buffer
1184 return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1185 }
1186
1187 // FIXME: this is a hack to avoid having to implement a new buffer type
1188 ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
1189 buffer->buft = buft;
1190 buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
1191
1192 return buffer;
1193}
1194
1195ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
1196 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
1197 static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
1198 /* .iface = */ {
1199 /* .get_name = */ ggml_backend_sycl_host_buffer_type_name,
1200 /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
1201 /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1202 /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
1203 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1204 /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1205 },
1206 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
1207 /* .context = */ nullptr,
1208 };
1209
1210 return &ggml_backend_sycl_buffer_type_host;
1211}
1212
1213// buffer pool for sycl (legacy)
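// Best-fit reuse pool: alloc() scans up to MAX_SYCL_BUFFERS cached allocations for the smallest one
// that fits; on a miss it over-allocates by 5% (look-ahead) to make future reuse more likely.
// free() caches the pointer in the first empty slot, falling back to an immediate free when the
// pool is full.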
1214struct ggml_sycl_pool_leg : public ggml_sycl_pool {
1215 static const int MAX_SYCL_BUFFERS = 256;
1216
1217 int device;
1218 queue_ptr qptr;
1219 struct ggml_sycl_buffer {
1220 void * ptr = nullptr;
1221 size_t size = 0;
1222 };
1223
1224 ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
1225 size_t pool_size = 0;
1226
1227 explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}
1228
1229 ~ggml_sycl_pool_leg() {
1230 for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1231 ggml_sycl_buffer & b = buffer_pool[i];
1232 if (b.ptr != nullptr) {
1233 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1234 pool_size -= b.size;
1235 }
1236 }
1237 GGML_ASSERT(pool_size == 0);
1238 }
1239
1240 void * alloc(size_t size, size_t * actual_size) override {
#ifdef DEBUG_SYCL_MALLOC
1242 int nnz = 0;
1243 size_t max_size = 0;
1244#endif
1245 size_t best_diff = 1ull << 36;
1246 int ibest = -1;
1247 for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1248 ggml_sycl_buffer& b = buffer_pool[i];
1249 if (b.ptr != nullptr) {
#ifdef DEBUG_SYCL_MALLOC
1251 ++nnz;
1252 if (b.size > max_size) max_size = b.size;
1253#endif
1254 if (b.size >= size) {
1255 size_t diff = b.size - size;
1256 if (diff < best_diff) {
1257 best_diff = diff;
1258 ibest = i;
1259 if (!best_diff) {
1260 void * ptr = b.ptr;
1261 *actual_size = b.size;
1262 b.ptr = nullptr;
1263 b.size = 0;
1264 return ptr;
1265 }
1266 }
1267 }
1268 }
1269 }
1270 if (ibest >= 0) {
1271 ggml_sycl_buffer& b = buffer_pool[ibest];
1272 void * ptr = b.ptr;
1273 *actual_size = b.size;
1274 b.ptr = nullptr;
1275 b.size = 0;
1276 return ptr;
1277 }
1278 void * ptr;
1279 size_t look_ahead_size = (size_t) (1.05 * size);
1280
1281 SYCL_CHECK(
1282 CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
1283 look_ahead_size, *qptr)));
1284 if (!ptr) {
            GGML_LOG_ERROR("%s: can't allocate %zu bytes of memory on device/GPU\n", __func__, look_ahead_size);
1286 return nullptr;
1287 }
1288
1289 *actual_size = look_ahead_size;
1290 pool_size += look_ahead_size;
1291
1292#ifdef DEBUG_SYCL_MALLOC
        GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
            (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024));
1295#endif
1296
1297 // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
1298 return ptr;
1299 }
1300
1301 void free(void * ptr, size_t size) override {
1302 for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
1303 ggml_sycl_buffer& b = buffer_pool[i];
1304 if (b.ptr == nullptr) {
1305 b.ptr = ptr;
1306 b.size = size;
1307 return;
1308 }
1309 }
        GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_SYCL_BUFFERS\n");
1311 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
1312 pool_size -= size;
1313 }
1314};
1315
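// Host (USM host) memory pool: slots are handed out in order via `counter`; once all MAX_POOL_SIZE
// slots have been allocated, requests wrap around and reuse existing host buffers. The buffers are
// only released when the pool itself is destroyed.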
1316struct ggml_sycl_pool_host : public ggml_sycl_pool {
1317 queue_ptr qptr;
1318 int device;
1319
1320 inline static int counter{ 0 };
1321
1322 struct ggml_sycl_buffer {
1323 void * ptr = nullptr;
1324 size_t size = 0;
1325 };
1326
    // Maximum number of host-buffer slots; set arbitrarily to 64.
1328 static constexpr int MAX_POOL_SIZE{ 64 };
1329 std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1330 size_t pool_size = 0;
1331
1332 explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
1333
1334 ~ggml_sycl_pool_host() {
1335 for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1336 ggml_sycl_buffer & b = buffer_pool[i];
1337 if (b.ptr != nullptr) {
1338 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1339 b.ptr = nullptr;
1340 pool_size -= b.size;
1341 b.size = 0;
1342 }
1343 }
1344 counter = 0;
1345 }
1346
1347 void * alloc(size_t size, size_t * actual_size) override {
1348 if (counter == MAX_POOL_SIZE) {
1349 ggml_sycl_buffer b = buffer_pool[0];
1350 void * ptr = b.ptr;
1351 *actual_size = b.size;
1352 counter = 1;
1353 return ptr;
1354 }
1355 ggml_sycl_buffer & b = buffer_pool[counter];
1356
1357 if (b.ptr == nullptr) {
1358 void * ptr;
1359
1360 SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1361 if (!ptr) {
                GGML_LOG_ERROR("%s: can't allocate %zu bytes of memory on host\n", __func__, size);
1363 return nullptr;
1364 }
1365 pool_size += size;
1366 *actual_size = size;
1367 counter = counter + 1;
1368 return ptr;
1369 } else {
1370 ++counter;
1371 b.size = size;
1372 return b.ptr;
1373 }
1374 }
1375
1376 void free(void * ptr, size_t size) override {
        // If the pool is not yet full, store the pointer in the first free slot.
        // Otherwise do nothing; the pointers will be freed when the pool itself is destroyed.
1379 for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1380 ggml_sycl_buffer & b = buffer_pool[i];
1381 if (b.ptr == nullptr) {
1382 b.ptr = ptr;
1383 b.size = size;
1384 return;
1385 }
1386 }
1387 }
1388};
1389
1390std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
1391 // return pool for the host to speed up memory management
1392 return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
1393}
1394
1395std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
1396 // TBD: NO VMM support
1397 // if (ggml_sycl_info().devices[device].vmm) {
1398 // return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
1399 // }
1400 return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
1401}
1402
1403// TBD pool with virtual memory management
1404// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
1405
1406/// kernels
1407typedef void (*ggml_sycl_op_mul_mat_t)(
1408 ggml_backend_sycl_context & ctx,
1409 const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
1410 const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
1411 float *dst_dd_i, const int64_t row_low, const int64_t row_high,
1412 const int64_t src1_ncols, const int64_t src1_padded_row_size,
1413 const queue_ptr &stream);
1414
1415
1416
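// mul_mat_p021_f16_f32: specialized product for an f16 x tensor stored transposed/permuted (hence
// "p021") against a permuted f32 y. Each (row, channel) pair produces one output value: work-items
// accumulate partial dot products over a strided column range, the sub-group reduces them with
// permute_sub_group_by_xor, and lane 0 writes the result.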
1417static void mul_mat_p021_f16_f32(
1418 const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
1419 const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
1420 const sycl::nd_item<3> &item_ct1) {
1421
1422 const sycl::half *x = (const sycl::half *)vx;
1423
1424 const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1425 item_ct1.get_local_id(1);
1426 const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1427 item_ct1.get_local_id(0);
1428 const int channel_x = channel / (nchannels_y / nchannels_x);
1429
1430 const int nrows_y = ncols_x;
1431 const int nrows_dst = nrows_x;
1432 const int row_dst = row_x;
1433
1434 float tmp = 0.0f;
1435
1436 for (int col_x0 = 0; col_x0 < ncols_x;
1437 col_x0 += item_ct1.get_local_range(2)) {
1438 const int col_x = col_x0 + item_ct1.get_local_id(2);
1439
1440 if (col_x >= ncols_x) {
1441 break;
1442 }
1443
1444 // x is transposed and permuted
1445 const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
1446 const float xi =
1447 sycl::vec<sycl::half, 1>(x[ix])
1448 .convert<float, sycl::rounding_mode::automatic>()[0];
1449
1450 const int row_y = col_x;
1451
1452
1453 // y is not transposed but permuted
1454 const int iy = channel*nrows_y + row_y;
1455
1456 tmp += xi * y[iy];
1457 }
1458
1459 // dst is not transposed and not permuted
1460 const int idst = channel*nrows_dst + row_dst;
1461
1462 // sum up partial sums and write back result
1463#pragma unroll
1464 for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
1465 tmp +=
1466 dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
1467 }
1468
1469 if (item_ct1.get_local_id(2) == 0) {
1470 dst[idst] = tmp;
1471 }
1472}
1473
1474static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1475 const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
1476 const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
1477 const sycl::nd_item<3> &item_ct1) {
1478
1479 const sycl::half *x = (const sycl::half *)vx;
1480
1481 const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1482 item_ct1.get_local_id(1);
1483 const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1484 item_ct1.get_local_id(0);
1485 const int channel_x = channel / channel_x_divisor;
1486
1487 const int nrows_dst = nrows_x;
1488 const int row_dst = row_x;
1489
1490 const int idst = channel*nrows_dst + row_dst;
1491
1492 float tmp = 0.0f;
1493
1494 for (int col_x0 = 0; col_x0 < ncols_x;
1495 col_x0 += item_ct1.get_local_range(2)) {
1496 const int col_x = col_x0 + item_ct1.get_local_id(2);
1497
1498 if (col_x >= ncols_x) {
1499 break;
1500 }
1501
1502 const int row_y = col_x;
1503
1504 const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
1505 const int iy = channel * channel_stride_y + row_y;
1506
1507 const float xi =
1508 sycl::vec<sycl::half, 1>(x[ix])
1509 .convert<float, sycl::rounding_mode::automatic>()[0];
1510
1511 tmp += xi * y[iy];
1512 }
1513
1514 // sum up partial sums and write back result
1515#pragma unroll
1516 for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
1517 tmp +=
1518 dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
1519 }
1520
1521 if (item_ct1.get_local_id(2) == 0) {
1522 dst[idst] = tmp;
1523 }
1524}
1525
1526static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
1527 const sycl::nd_item<3> &item_ct1) {
1528 const int row = item_ct1.get_group(1);
1529 const int col = item_ct1.get_local_id(2);
1530
1531 float sum = 0.0f;
1532 for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
1533 sum += x[row * ncols + i];
1534 }
1535
1536 sum = warp_reduce_sum(sum, item_ct1);
1537
1538 if (col == 0) {
1539 dst[row] = sum;
1540 }
1541}
1542
1543
1544template<typename T>
1545static inline void ggml_sycl_swap(T & a, T & b) {
1546 T tmp = a;
1547 a = b;
1548 b = tmp;
1549}
1550
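// k_argsort_f32_i32: bitonic argsort of one row per work-group. The index array lives in local
// memory and the row is padded to a power of two (ncols_pad); each work-item owns tasks_per_thread
// consecutive columns. Padding indices (>= ncols) always lose the comparison, so they end up past
// the first ncols positions and are dropped when the result is copied back to dst.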
1551template <ggml_sort_order order>
1552__dpct_inline__ static void
1553k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1554 const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
1555 uint8_t *dpct_local) {
1556 // bitonic sort
1557 int col_index = item_ct1.get_local_id(2);
1558 int row = item_ct1.get_group(1);
1559
1560 for (int i = 0; i < tasks_per_thread; i++) {
1561 int col = col_index * tasks_per_thread + i;
1562 if (col >= ncols_pad) {
1563 return;
1564 }
1565 }
1566
1567 const float * x_row = x + row * ncols;
1568 auto dst_row = (int *)dpct_local;
1569
1570 // initialize indices
1571    for (int i = 0; i < tasks_per_thread; i++) {
1572        int col = col_index * tasks_per_thread + i;
1573        dst_row[col] = col;
1574    }
1575
1576 item_ct1.barrier(sycl::access::fence_space::local_space);
1577
1578 for (int k = 2; k <= ncols_pad; k *= 2) {
1579 for (int j = k / 2; j > 0; j /= 2) {
1580 for (int i = 0; i < tasks_per_thread; i++) {
1581 int col = col_index * tasks_per_thread + i;
1582 int ixj = col ^ j;
1583 if (ixj > col) {
1584 if ((col & k) == 0) {
1585 if (dst_row[col] >= ncols ||
1586 (dst_row[ixj] < ncols &&
1587 (order == GGML_SORT_ORDER_ASC
1588 ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
1589 : x_row[dst_row[col]] <
1590 x_row[dst_row[ixj]]))) {
1591 ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1592 }
1593 } else {
1594 if (dst_row[ixj] >= ncols ||
1595 (dst_row[col] < ncols &&
1596 (order == GGML_SORT_ORDER_ASC
1597 ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
1598 : x_row[dst_row[col]] >
1599 x_row[dst_row[ixj]]))) {
1600 ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1601 }
1602 }
1603 }
1604 item_ct1.barrier(sycl::access::fence_space::local_space);
1605 }
1606 }
1607 }
1608
1609 // copy the result to dst without the padding
1610 for (int i = 0; i < tasks_per_thread; i++) {
1611 int col = col_index * tasks_per_thread + i;
1612 if (col < ncols) {
1613 dst[row * ncols + col] = dst_row[col];
1614 }
1615 }
1616}
1617
1618static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
1619 const sycl::nd_item<3> &item_ct1) {
1620 const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1621 item_ct1.get_local_id(1);
1622 const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1623 item_ct1.get_local_id(2);
1624
1625 if (col >= ncols) {
1626 return;
1627 }
1628
1629 const int i = row*ncols + col;
1630 //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
1631 //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
1632 dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
1633}
1634
1635static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
1636 const sycl::nd_item<3> &item_ct1) {
1637 const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1638 item_ct1.get_local_id(2);
1639
1640 if (i >= k) {
1641 return;
1642 }
1643
1644 dst[i] = scale * x[i] + bias;
1645}
1646
1647
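// pool2d_nchw_kernel: each work-item produces one output element of an NCHW pooling. The
// flat index is decomposed into (channel, output row, output column), the pooling window is
// clamped to the input bounds, and GGML_OP_POOL_AVG divides by the full kernel area kh*kw,
// i.e. padded positions are included in the average's denominator.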
1648template <typename Ti, typename To>
1649static void pool2d_nchw_kernel(
1650 const int ih, const int iw, const int oh, const int ow,
1651 const int kh, const int kw, const int sh, const int sw,
1652 const int ph, const int pw, const int parallel_elements,
1653 const Ti* src, To* dst, const enum ggml_op_pool op,
1654 const sycl::nd_item<3> &item_ct1) {
1655 int idx = item_ct1.get_local_id(2) +
1656 item_ct1.get_group(2) * item_ct1.get_local_range(2);
1657 if (idx >= parallel_elements) {
1658 return;
1659 }
1660
1661 const int I_HW = ih * iw;
1662 const int O_HW = oh * ow;
1663 const int nc = idx / O_HW;
1664 const int cur_oh = idx % O_HW / ow;
1665 const int cur_ow = idx % O_HW % ow;
1666 const Ti* i_ptr = src + nc * I_HW;
1667 To* o_ptr = dst + nc * O_HW;
1668 const int start_h = cur_oh * sh - ph;
1669 const int bh = sycl::max(0, start_h);
1670 const int eh = sycl::min(ih, start_h + kh);
1671 const int start_w = cur_ow * sw - pw;
1672 const int bw = sycl::max(0, start_w);
1673 const int ew = sycl::min(iw, start_w + kw);
1674
1675 To res = 0;
1676
1677 switch (op) {
1678 case GGML_OP_POOL_AVG: res = 0; break;
1679 case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
1680 default:
1681 res = (To) sycl::nan(uint32_t(0));
1682 break;
1683 }
1684
1685 for (int i = bh; i < eh; i += 1) {
1686 for (int j = bw; j < ew; j += 1) {
1687#if DPCT_COMPATIBILITY_TEMP >= 350
1688 /*
1689 DPCT1098:106: The '*' expression is used instead of the __ldg
1690 call. These two expressions do not provide the exact same
1691 functionality. Check the generated code for potential precision
1692 and/or performance issues.
1693 */
1694 Ti cur = *(i_ptr + i * iw + j);
1695#else
1696 Ti cur = i_ptr[i * iw + j];
1697#endif
1698 switch (op) {
1699 case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
1700 case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
1701 default:
1702 res = (To) sycl::nan(uint32_t(0));
1703 break;
1704 }
1705 }
1706 }
1707 o_ptr[cur_oh * ow + cur_ow] = res;
1708}
1709
1710
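// Launch helper for mul_mat_p021_f16_f32: one sub-group of WARP_SIZE work-items per
// (channel, row) pair of the output, arranged as an (nchannels_y, nrows_x) grid. The fp16
// aspect is required because src0 is read as sycl::half.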
1711static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
1712 float *dst, const int ncols_x,
1713 const int nrows_x,
1714 const int nchannels_x,
1715 const int nchannels_y,
1716 queue_ptr stream) {
1717
1718 const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
1719 const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1720 {
1721 dpct::has_capability_or_fail(stream->get_device(),
1722 {sycl::aspect::fp16});
1723
1724 stream->parallel_for(
1725 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1726 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1727 mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
1728 nchannels_y, item_ct1);
1729 });
1730 }
1731}
1732
1733static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1734 const void *vx, const float *y, float *dst, const int ncols_x,
1735 const int nrows_x, const int row_stride_x, const int nchannels_x,
1736 const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
1737
1738 const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
1739 const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1740 {
1741 dpct::has_capability_or_fail(stream->get_device(),
1742 {sycl::aspect::fp16});
1743
1744 stream->parallel_for(
1745 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1746 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1747 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
1748 row_stride_x, channel_stride_x, channel_stride_y,
1749 nchannels_y / nchannels_x, item_ct1);
1750 });
1751 }
1752}
1753
1754
1755
1756static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
1757 const int k, queue_ptr stream) {
1758 const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
1759 stream->parallel_for(
1760 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
1761 sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
1762 sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
1763 [=](sycl::nd_item<3> item_ct1) {
1764 scale_f32(x, dst, scale, bias, k, item_ct1);
1765 });
1766}
1767
1768
1769static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
1770 const int nrows, queue_ptr stream) {
1771 const sycl::range<3> block_dims(1, 1, WARP_SIZE);
1772 const sycl::range<3> block_nums(1, nrows, 1);
1773 stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
1774 [=](sycl::nd_item<3> item_ct1)
1775 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1776 k_sum_rows_f32(x, dst, ncols, item_ct1);
1777 });
1778}
1779
1780static int next_power_of_2(int x) {
1781 int n = 1;
1782 while (n < x) {
1783 n *= 2;
1784 }
1785 return n;
1786}
1787
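// argsort launcher: pads ncols to the next power of two as required by the bitonic sort
// (e.g. ncols = 1000 -> ncols_pad = 1024), chooses a work-group size up to the device
// maximum, and derives tasks_per_thread for rows wider than the work-group. The
// ncols_pad * sizeof(int) index buffer must fit in local memory, which the assert below
// checks against the device limit.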
1788static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1789 const int nrows, ggml_sort_order order,
1790 queue_ptr stream, int device) {
1791 // bitonic sort requires ncols to be power of 2
1792 const int ncols_pad = next_power_of_2(ncols);
1793
1794 int nth = 1;
1795 int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
1796 while (nth < ncols_pad && nth < max_block_size)
1797 nth *= 2;
1798 if (nth > max_block_size)
1799 nth = max_block_size;
1800
1801 const int tasks_per_thread = ncols_pad / nth;
1802
1803 const sycl::range<3> block_dims(1, 1, nth);
1804 const sycl::range<3> block_nums(1, nrows, 1);
1805 const size_t shared_mem = ncols_pad * sizeof(int);
1806 GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
1807
1808 if (order == GGML_SORT_ORDER_ASC) {
1809 stream->submit([&](sycl::handler &cgh) {
1810 sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1811 sycl::range<1>(shared_mem), cgh);
1812
1813 cgh.parallel_for(
1814 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1815 [=](sycl::nd_item<3> item_ct1) {
1816 k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
1817 x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1818 dpct_local_acc_ct1
1819 .get_multi_ptr<sycl::access::decorated::no>()
1820 .get());
1821 });
1822 });
1823 } else if (order == GGML_SORT_ORDER_DESC) {
1824 stream->submit([&](sycl::handler &cgh) {
1825 sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1826 sycl::range<1>(shared_mem), cgh);
1827
1828 cgh.parallel_for(
1829 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1830 [=](sycl::nd_item<3> item_ct1) {
1831 k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
1832 x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1833 dpct_local_acc_ct1
1834 .get_multi_ptr<sycl::access::decorated::no>()
1835 .get());
1836 });
1837 });
1838 } else {
1839 GGML_ABORT("fatal error");
1840 }
1841}
1842
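// top_k_f32_sycl: two-stage selection of the k largest values per row. Each work-item keeps
// a private, descending top-k list over a strided subset of the columns (fixed 32-entry
// arrays, hence the k <= 32 limit enforced by the caller), the candidates are staged in
// local memory, and work-item 0 merges them and writes the k column indices for the row.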
1843static void top_k_f32_sycl(
1844 const float * src,
1845 int32_t * dst_indices,
1846 const int64_t ncols,
1847 const int64_t nrows,
1848 const int k,
1849 dpct::queue_ptr main_stream
1850) {
1851 const int block_size = 128;
1852
1853 const sycl::range<1> block_dims(block_size);
1854 const sycl::range<1> grid_dims(nrows);
1855
1856 main_stream->submit([&](sycl::handler &cgh) {
1857 sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
1858 sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
1859
1860 cgh.parallel_for(
1861 sycl::nd_range<1>(grid_dims * block_dims, block_dims),
1862 [=](sycl::nd_item<1> item_ct1) {
1863 const int row = item_ct1.get_group(0);
1864 const int tid = item_ct1.get_local_id(0);
1865
1866 if (row >= nrows) return;
1867
1868 const float * src_row = src + row * ncols;
1869 int32_t * dst_idx_row = dst_indices + row * k;
1870
1871 float local_vals[32];
1872 int local_idx[32];
1873
1874 for (int i = 0; i < k; i++) {
1875 local_vals[i] = -FLT_MAX;
1876 local_idx[i] = -1;
1877 }
1878
1879 for (int col = tid; col < ncols; col += block_size) {
1880 float val = src_row[col];
1881
1882 if (val > local_vals[k-1]) {
1883 int pos = k - 1;
1884 while (pos > 0 && val > local_vals[pos - 1]) {
1885 pos--;
1886 }
1887
1888 for (int i = k - 1; i > pos; i--) {
1889 local_vals[i] = local_vals[i - 1];
1890 local_idx[i] = local_idx[i - 1];
1891 }
1892 local_vals[pos] = val;
1893 local_idx[pos] = col;
1894 }
1895 }
1896
1897 for (int i = 0; i < k; i++) {
1898 shared_vals[tid * k + i] = local_vals[i];
1899 shared_idx[tid * k + i] = local_idx[i];
1900 }
1901 item_ct1.barrier(sycl::access::fence_space::local_space);
1902
1903 if (tid == 0) {
1904 float final_vals[32];
1905 int final_idx[32];
1906
1907 for (int i = 0; i < k; i++) {
1908 final_vals[i] = -FLT_MAX;
1909 final_idx[i] = -1;
1910 }
1911
1912 for (int t = 0; t < block_size; t++) {
1913 for (int i = 0; i < k; i++) {
1914 float val = shared_vals[t * k + i];
1915 int idx = shared_idx[t * k + i];
1916
1917 if (val > final_vals[k-1]) {
1918 int pos = k - 1;
1919 while (pos > 0 && val > final_vals[pos - 1]) {
1920 pos--;
1921 }
1922
1923 for (int j = k - 1; j > pos; j--) {
1924 final_vals[j] = final_vals[j - 1];
1925 final_idx[j] = final_idx[j - 1];
1926 }
1927 final_vals[pos] = val;
1928 final_idx[pos] = idx;
1929 }
1930 }
1931 }
1932
1933 for (int i = 0; i < k; i++) {
1934 dst_idx_row[i] = final_idx[i];
1935 }
1936
1942 }
1943 });
1944 });
1945}
1946
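// argmax_f32_i32_sycl: one work-group of 256 work-items per row. Each work-item scans a
// strided subset of the columns, a shared-memory tree reduction keeps the (value, index)
// pair of the maximum, and work-item 0 writes the winning index.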
1947static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
1948 const int nrows, queue_ptr stream) {
1949 const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
1950 const sycl::range<3> block_nums(1, nrows, 1);
1951 const size_t shared_mem = 256 * sizeof(float);
1952
1953 stream->submit([&](sycl::handler &cgh) {
1954 sycl::local_accessor<float, 1> shared_data(
1955 sycl::range<1>(shared_mem/sizeof(float)), cgh);
1956 sycl::local_accessor<int, 1> shared_indices(
1957 sycl::range<1>(shared_mem/sizeof(float)), cgh);
1958
1959 cgh.parallel_for(
1960 sycl::nd_range<3>(block_nums * block_dims, block_dims),
1961 [=](sycl::nd_item<3> item_ct1) {
1962 const int tid = item_ct1.get_local_id(2);
1963 const int row = item_ct1.get_global_id(1);
1964
1965 float max_val = -INFINITY;
1966 int max_idx = -1;
1967
1968 for (int col = tid; col < ncols; col += 256) {
1969 float val = x[row * ncols + col];
1970 if (val > max_val) {
1971 max_val = val;
1972 max_idx = col;
1973 }
1974 }
1975
1976 shared_data[tid] = max_val;
1977 shared_indices[tid] = max_idx;
1978 item_ct1.barrier(sycl::access::fence_space::local_space);
1979
1980 for (int stride = 256/2; stride > 0; stride >>= 1) {
1981 if (tid < stride) {
1982 float val1 = shared_data[tid];
1983 float val2 = shared_data[tid + stride];
1984 if (val2 > val1) {
1985 shared_data[tid] = val2;
1986 shared_indices[tid] = shared_indices[tid + stride];
1987 }
1988 }
1989 item_ct1.barrier(sycl::access::fence_space::local_space);
1990 }
1991
1992
1993 if (tid == 0) {
1994 dst[row] = shared_indices[0];
1995 }
1996 });
1997 });
1998}
1999static void diag_mask_inf_f32_sycl(const float *x, float *dst,
2000 const int ncols_x, const int nrows_x,
2001 const int rows_per_channel, const int n_past,
2002 queue_ptr stream) {
2003 const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
2004 const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
2005 const sycl::range<3> block_nums(1, block_num_x, nrows_x);
2006 stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
2007 [=](sycl::nd_item<3> item_ct1) {
2008 diag_mask_inf_f32(x, dst, ncols_x,
2009 rows_per_channel, n_past,
2010 item_ct1);
2011 });
2012}
2013
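// Copies rows [i1_low, i1_high) of the (i2, i3) slice of src into dst. The memcpy direction
// is chosen from the source buffer type (host, SYCL device, or SYCL split), and the copy
// degrades gracefully: a single bulk copy when the rows are contiguous, a 2D strided copy
// when only the elements within a row are contiguous, and otherwise a row-by-row copy.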
2014static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
2015 const struct ggml_tensor *src,
2016 int64_t i3, int64_t i2,
2017 int64_t i1_low, int64_t i1_high,
2018 queue_ptr stream) try {
2019
2020 dpct::memcpy_direction kind;
2021 char * src_ptr;
2022 if (ggml_backend_buffer_is_host(src->buffer)) {
2023 kind = dpct::host_to_device;
2024 //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
2025 src_ptr = (char *) src->data;
2026 // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
2027 } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
2028 // If buffer is a SYCL buffer
2029 //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
2030 kind = dpct::device_to_device;
2031 src_ptr = (char *) src->data;
2032 } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
2033 /*
2034 If buffer is a SYCL split buffer
2035 */
2036 //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
2037 GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
2038 kind = dpct::device_to_device;
2039 ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
2040 int id;
2041 SYCL_CHECK(CHECK_TRY_ERROR(
2042 id = get_current_device_id()));
2043 // GGML_SYCL_DEBUG("current device index %d\n", id);
2044 src_ptr = (char *) extra->data_device[id];
2045 } else {
2046 // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
2047 GGML_ABORT("fatal error");
2048 }
2049 char * dst_ptr = (char *) dst;
2050
2051 GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
2052 GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
2053 const enum ggml_type type = src->type;
2054 const int64_t ts = ggml_type_size(type);
2055 const int64_t bs = ggml_blck_size(type);
2056 int64_t i1_diff = i1_high - i1_low;
2057
2058 const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
2059 if (nb0 == ts && nb1 == ts*ne0/bs) {
2060 // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
2061 // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
2062 return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
2063 kind, *stream));
2064
2065 } else if (nb0 == ts) {
2066 return CHECK_TRY_ERROR(
2067 dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
2068 ts * ne0 / bs, i1_diff, kind, *stream));
2069 } else {
2070 for (int64_t i1 = 0; i1 < i1_diff; i1++) {
2071 const void * rx = (const void *) ((const char *) x + i1*nb1);
2072 void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
2073 // pretend the row is a matrix with cols=1
2074 dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2075 rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
2076 /*
2077 DPCT1001:85: The statement could not be removed.
2078 */
2079 /*
2080 DPCT1000:86: Error handling if-stmt was detected but could not be
2081 rewritten.
2082 */
2083 if (r != 0) return r;
2084 }
2085 return 0;
2086 }
2087}
2088catch (sycl::exception const &exc) {
2089 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2090 << ", line:" << __LINE__ << std::endl;
2091 std::exit(1);
2092}
2093
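// GEMM for one row slice of the output. If src0 is fp16 (or quantized) and fp16 math is
// enabled, both inputs are converted to fp16 and multiplied with oneDNN or oneMKL; the
// oneMKL path accumulates into an fp16 buffer that is converted back to fp32 afterwards.
// Otherwise both inputs are converted to fp32 and a plain fp32 GEMM is used.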
2094inline void ggml_sycl_op_mul_mat_sycl(
2095 ggml_backend_sycl_context & ctx,
2096 const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
2097 const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
2098 float *dst_dd_i, const int64_t row_low, const int64_t row_high,
2099 const int64_t src1_ncols, const int64_t src1_padded_row_size,
2100 const queue_ptr &stream) try {
2101
2102 GGML_ASSERT(src0_dd_i != nullptr);
2103 GGML_ASSERT(src1_ddf_i != nullptr);
2104 GGML_ASSERT(dst_dd_i != nullptr);
2105
2106 const int64_t ne00 = src0->ne[0];
2107 const int64_t ne10 = src1->ne[0];
2108 GGML_ASSERT(ne00 == ne10);
2109
2110 const int64_t row_diff = row_high - row_low;
2111
2112 int id;
2113 SYCL_CHECK(
2114 CHECK_TRY_ERROR(id = get_current_device_id()));
2115
2116 const int64_t ne0 = dst->ne[0]; // used by MKL only
2117 // the main device has a larger memory buffer to hold the results from all GPUs
2118    // ldc == nrows of the matrix that the BLAS GEMM writes into
2119 int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
2120
2121#ifdef GGML_SYCL_F16
2122 bool use_fp16 = true; // TODO(Yu) SYCL capability check
2123#else
2124 bool use_fp16 = false;
2125#endif
2126 if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
2127 row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
2128 ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
2129 if (src0->type != GGML_TYPE_F16) {
2130 scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2131 " : converting src0 to fp16");
2132 const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
2133 GGML_ASSERT(to_fp16_sycl != nullptr);
2134 size_t ne = row_diff*ne00;
2135 src0_as_f16.alloc(ne);
2136 to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
2137 }
2138 const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
2139 ? (const sycl::half *)src0_dd_i
2140 : src0_as_f16.get();
2141
2142 ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
2143 if (src1->type != GGML_TYPE_F16) {
2144 scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2145 " : converting src1 to fp16");
2146 const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2147 GGML_ASSERT(to_fp16_sycl != nullptr);
2148 size_t ne = src1_ncols*ne10;
2149 src1_as_f16.alloc(ne);
2150 to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
2151 }
2152 const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
2153 ? (const sycl::half *)src1->data + src1_padded_row_size
2154 : src1_as_f16.get();
2155
2156#if GGML_SYCL_DNNL
2157 if (!g_ggml_sycl_disable_dnn) {
2158            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
2159 DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2160 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2161 }
2162 else
2163#endif
2164 {
2165 ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2166
2167 const sycl::half alpha_f16 = 1.0f;
2168 const sycl::half beta_f16 = 0.0f;
2169 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2170 *stream, oneapi::mkl::transpose::trans,
2171 oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2172 &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
2173 src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2174 dst_f16.get(), dpct::library_data_t::real_half, ldc,
2175 dpct::library_data_t::real_half)));
2176 scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2177 " : converting dst to fp32");
2178 const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2179 to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2180 }
2181 } else {
2182 ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
2183 ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
2184 if (src0->type != GGML_TYPE_F32) {
2185 scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2186 " : converting src0 to fp32");
2187 const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
2188 GGML_ASSERT(to_fp32_sycl != nullptr);
2189 src0_ddq_as_f32.alloc(row_diff*ne00);
2190 to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
2191 }
2192 if (src1->type != GGML_TYPE_F32) {
2193 scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2194 " : converting src1 to fp32");
2195 const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
2196 GGML_ASSERT(to_fp32_sycl != nullptr);
2197 src1_ddq_as_f32.alloc(src1_ncols*ne10);
2198 to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
2199 }
2200 const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
2201 const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
2202
2203#if GGML_SYCL_DNNL
2204 if (!g_ggml_sycl_disable_dnn) {
2205 DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
2206 DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2207 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2208 }
2209 else
2210#endif
2211 {
2212 const float alpha = 1.0f;
2213 const float beta = 0.0f;
2214 SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2215 *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
2216 src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
2217 dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2218 }
2219 }
2220 GGML_UNUSED(dst);
2221 GGML_UNUSED(src1_ddq_i);
2222 GGML_UNUSED(src1_padded_row_size);
2223}
2224catch (sycl::exception const &exc) {
2225 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2226 << ", line:" << __LINE__ << std::endl;
2227 std::exit(1);
2228}
2229
2230static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2231 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2232 GGML_ASSERT( dst->type == GGML_TYPE_F32);
2233 dpct::queue_ptr main_stream = ctx.stream();
2234 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2235 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2236 float * dst_dd = static_cast<float *>(dst->data);
2237
2238 const int32_t * opts = (const int32_t *)dst->op_params;
2239 enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
2240 const int k0 = opts[1];
2241 const int k1 = opts[2];
2242 const int s0 = opts[3];
2243 const int s1 = opts[4];
2244 const int p0 = opts[5];
2245 const int p1 = opts[6];
2246
2247 const int64_t IH = dst->src[0]->ne[1];
2248 const int64_t IW = dst->src[0]->ne[0];
2249
2250 const int64_t N = dst->ne[3];
2251 const int64_t OC = dst->ne[2];
2252 const int64_t OH = dst->ne[1];
2253 const int64_t OW = dst->ne[0];
2254
2255 const int parallel_elements = N * OC * OH * OW;
2256 const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
2257 sycl::range<3> block_nums(1, 1, num_blocks);
2258 main_stream->parallel_for(
2259        sycl::nd_range<3>(block_nums *
2260                              sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE),
2261                          sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE)),
2262 [=](sycl::nd_item<3> item_ct1) {
2263 pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
2264 parallel_elements, src0_dd, dst_dd, op,
2265 item_ct1);
2266 });
2267}
2268
2269inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2270 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2271 GGML_ASSERT( dst->type == GGML_TYPE_F32);
2272 dpct::queue_ptr main_stream = ctx.stream();
2273 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2274 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2275 float * dst_dd = static_cast<float *>(dst->data);
2276
2277 const int64_t ne = ggml_nelements(dst->src[0]);
2278
2279 sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
2280}
2281
2282inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2283 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2284 GGML_ASSERT( dst->type == GGML_TYPE_F32);
2285 dpct::queue_ptr main_stream = ctx.stream();
2286 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2287 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2288 float * dst_dd = static_cast<float *>(dst->data);
2289
2290 const int64_t ncols = dst->src[0]->ne[0];
2291 const int64_t nrows = ggml_nrows(dst->src[0]);
2292
2293 sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2294}
2295
2296inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2297 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2298 GGML_ASSERT(dst->type == GGML_TYPE_F32);
2299
2300 dpct::queue_ptr main_stream = ctx.stream();
2301 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2302
2303 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2304 float * dst_dd = static_cast<float *>(dst->data);
2305
2306 const int64_t ncols = dst->src[0]->ne[0];
2307 const int64_t nrows = ggml_nrows(dst->src[0]);
2308
2309 sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2310
2311 main_stream->parallel_for(
2312 sycl::range<1>(nrows),
2313 [=](sycl::id<1> row) {
2314 dst_dd[row] /= ncols;
2315 }
2316 );
2317}
2318
2319
2320inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2321 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2322 GGML_ASSERT(dst->type == GGML_TYPE_I32);
2323 dpct::queue_ptr main_stream = ctx.stream();
2324 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2325 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2326 int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2327
2328
2329 const int64_t ncols = dst->src[0]->ne[0];
2330 const int64_t nrows = ggml_nrows(dst->src[0]);
2331
2332 enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2333
2334 argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
2335 main_stream, ctx.device);
2336}
2337
2338static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2339 const ggml_tensor * src0 = dst->src[0];
2340
2341 GGML_ASSERT(src0);
2342 GGML_ASSERT(src0->type == GGML_TYPE_F32);
2343 GGML_ASSERT(dst->type == GGML_TYPE_I32);
2344 GGML_ASSERT(ggml_is_contiguous(src0));
2345
2346 dpct::queue_ptr main_stream = ctx.stream();
2347 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2348
2349 const float * src0_dd = static_cast<const float *>(src0->data);
2350 int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2351
2352 const int k = dst->ne[0];
2353 const int64_t ncols = src0->ne[0];
2354 const int64_t nrows = ggml_nrows(src0);
2355
2356 GGML_ASSERT(k > 0 && k <= 32);
2357 GGML_ASSERT(k <= ncols);
2358
2359 top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
2360}
2361
2362inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2363 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2364 GGML_ASSERT( dst->type == GGML_TYPE_I32);
2365
2366 dpct::queue_ptr main_stream = ctx.stream();
2367 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2368 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2369 int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2370
2371 const int64_t ncols = dst->src[0]->ne[0];
2372 const int64_t nrows = ggml_nrows(dst->src[0]);
2373
2374 argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2375}
2376
2377inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2378 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2379 GGML_ASSERT( dst->type == GGML_TYPE_F32);
2380 dpct::queue_ptr main_stream = ctx.stream();
2381 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2382 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2383 float * dst_dd = static_cast<float *>(dst->data);
2384
2385 const int64_t ne00 = dst->src[0]->ne[0];
2386 const int64_t ne01 = dst->src[0]->ne[1];
2387 const int nrows0 = ggml_nrows(dst->src[0]);
2388
2389 const int n_past = ((int32_t *) dst->op_params)[0];
2390
2391 diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
2392}
2393
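// tri_f32_sycl: one work-item per element. The flat index is decomposed into
// (i0 = column, i1 = row) and the element is kept or zeroed according to the requested
// triangular mask (strict or inclusive lower/upper triangle).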
2394static void tri_f32_sycl(
2395 const float * src,
2396 float * dst,
2397 const int64_t ne0,
2398 const int64_t ne1,
2399 const int64_t ne2,
2400 const int64_t ne3,
2401 const ggml_tri_type ttype,
2402 dpct::queue_ptr main_stream
2403) {
2404 const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
2405
2406 main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
2407 const int64_t idx = (int64_t) tid[0];
2408
2409 const int64_t i0 = idx % ne0;
2410 const int64_t t1 = idx / ne0;
2411 const int64_t i1 = t1 % ne1;
2412
2413 bool keep = false;
2414 switch (ttype) {
2415 case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
2416 case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
2417 case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
2418 case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
2419 default: keep = false; break;
2420 }
2421
2422 dst[idx] = keep ? src[idx] : 0.0f;
2423 });
2424}
2425
2426static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2427 const ggml_tensor * src0 = dst->src[0];
2428 GGML_ASSERT(src0);
2429
2430 GGML_ASSERT(src0->type == GGML_TYPE_F32);
2431 GGML_ASSERT(dst->type == GGML_TYPE_F32);
2432 GGML_ASSERT(ggml_is_contiguous(src0));
2433 GGML_ASSERT(ggml_is_contiguous(dst));
2434 GGML_ASSERT(ggml_are_same_shape(src0, dst));
2435
2436 dpct::queue_ptr main_stream = ctx.stream();
2437 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2438
2439 const float * src0_dd = static_cast<const float *>(src0->data);
2440 float * dst_dd = static_cast<float *>(dst->data);
2441
2442 const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2443
2444 const int64_t ne0 = src0->ne[0];
2445 const int64_t ne1 = src0->ne[1];
2446 const int64_t ne2 = src0->ne[2];
2447 const int64_t ne3 = src0->ne[3];
2448
2449 tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
2450}
2451
2452
2453inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2454 GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2455 GGML_ASSERT( dst->type == GGML_TYPE_F32);
2456 dpct::queue_ptr main_stream = ctx.stream();
2457 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2458 const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2459 float * dst_dd = static_cast<float *>(dst->data);
2460
2461 float scale;
2462 float bias;
2463 memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
2464 memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
2465
2466 scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
2467 /*
2468 DPCT1010:87: SYCL uses exceptions to report errors and does not use the
2469 error codes. The call was replaced with 0. You need to rewrite this code.
2470 */
2471 SYCL_CHECK(0);
2472}
2473
2474static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
2475 static bool peer_access_enabled = false;
2476
2477 const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
2478
2479 if (peer_access_enabled == enable_peer_access) {
2480 return;
2481 }
2482
2483#ifdef NDEBUG
2484 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2485 SYCL_CHECK(ggml_sycl_set_device(i));
2486 }
2487
2488 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2489 SYCL_CHECK(ggml_sycl_set_device(i));
2490
2491 for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
2492 if (i == id_other) {
2493 continue;
2494 }
2495 if (i != main_device && id_other != main_device) {
2496 continue;
2497 }
2498
2499 // int can_access_peer;
2500 // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
2501 // if (can_access_peer) {
2502 // if (enable_peer_access) {
2503 // SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
2504 // } else {
2505 // SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
2506 // }
2507 // }
2508 }
2509 }
2510#endif // NDEBUG
2511
2512 peer_access_enabled = enable_peer_access;
2513}
2514
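// Multi-device mul_mat dispatcher. Rows of src0 are split across devices according to
// tensor_split (rounded to tile boundaries), src1 is optionally quantized to Q8_1 when the
// selected op consumes quantized activations, and src1 columns are processed in chunks of
// MUL_MAT_SRC1_COL_STRIDE on per-device streams. Events make the secondary devices wait for
// the main device's input and make the main device wait for their partial results before
// returning.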
2515template <template <int> typename quantize_f>
2516static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2517 const ggml_tensor *src1, ggml_tensor *dst,
2518 ggml_sycl_op_mul_mat_t op) try {
2519
2520 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
2521
2522 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
2523 const int64_t nrows1 = ggml_nrows(src1);
2524
2525 GGML_ASSERT(ne03 == ne13);
2526
2527 const int64_t ne0 = dst->ne[0];
2528 const int64_t ne1 = dst->ne[1];
2529
2530 const int nb2 = dst->nb[2];
2531 const int nb3 = dst->nb[3];
2532
2533 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
2534 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
2535 GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
2536
2537 GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
2538
2539 const int64_t i02_divisor = ne12 / ne02;
2540
2541 const size_t src0_ts = ggml_type_size(src0->type);
2542 const size_t src0_bs = ggml_blck_size(src0->type);
2543 const size_t q8_1_ts = sizeof(block_q8_1);
2544 const size_t q8_1_bs = QK8_1;
2545
2546 ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
2547 ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
2548
2549 const bool src0_is_contiguous = ggml_is_contiguous(src0);
2550 const bool src1_is_contiguous = ggml_is_contiguous(src1);
2551
2552 int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
2553
2554 const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2555 GGML_ASSERT(!(split && ne02 > 1));
2556 GGML_ASSERT(!(split && ne03 > 1));
2557 GGML_ASSERT(!(split && ne02 < ne12));
2558
2559 std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
2560 if (split) {
2561 // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
2562 // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
2563 ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
2564 tensor_split = buft_ctx->tensor_split;
2565 }
2566
2567 struct dev_data {
2568 ggml_sycl_pool_alloc<char> src0_dd_alloc;
2569 ggml_sycl_pool_alloc<float> src1_ddf_alloc;
2570 ggml_sycl_pool_alloc<char> src1_ddq_alloc;
2571 ggml_sycl_pool_alloc<float> dst_dd_alloc;
2572
2573 char *src0_dd = nullptr;
2574 float *src1_ddf = nullptr; // float
2575 char *src1_ddq = nullptr; // q8_1
2576 float *dst_dd = nullptr;
2577
2578 int64_t row_low;
2579 int64_t row_high;
2580 };
2581
2582 dev_data dev[GGML_SYCL_MAX_DEVICES];
2583
2584 int used_devices = 0;
2585 queue_ptr main_stream = ctx.stream();
2586
2587 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2588 // by default, use all rows
2589 dev[i].row_low = 0;
2590 dev[i].row_high = ne01;
2591
2592 // for multi GPU, get the row boundaries from tensor split
2593 // and round to mul_mat_q tile sizes
2594 if (split) {
2595 const int64_t rounding = get_row_rounding(src0->type, tensor_split);
2596
2597 if (i != 0) {
2598 dev[i].row_low = ne01*tensor_split[i];
2599 if (dev[i].row_low < ne01) {
2600 dev[i].row_low -= dev[i].row_low % rounding;
2601 }
2602 }
2603
2604 if (i != ggml_sycl_info().device_count - 1) {
2605 dev[i].row_high = ne01*tensor_split[i + 1];
2606 if (dev[i].row_high < ne01) {
2607 dev[i].row_high -= dev[i].row_high % rounding;
2608 }
2609 }
2610 }
2611 }
2612
2613 constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
2614 no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
2615 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2616 if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
2617 continue;
2618 }
2619
2620 used_devices++;
2621
2622 const bool src1_on_device = i == ctx.device;
2623 const bool dst_on_device = i == ctx.device;
2624
2625 ggml_sycl_set_device(i);
2626 queue_ptr stream = ctx.stream(i, 0);
2627
2628 if (src0_is_contiguous) {
2629 dev[i].src0_dd = (char *) src0->data;
2630 } else {
2631 dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
2632 }
2633
2634 if (src1_on_device && src1_is_contiguous) {
2635 dev[i].src1_ddf = (float *) src1->data;
2636 } else {
2637 dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
2638 }
2639
2640 if constexpr(quantize_enabled) {
2641 dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
2642
2643 if (src1_on_device && src1_is_contiguous) {
2644 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2645 /*num_src=*/2, " : converting src1 to Q8_1");
2646 try {
2647 quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2648 } catch (sycl::exception const &exc) {
2649                    std::cerr << "quantize_row_q8_1_sycl error: " << exc.what() << ", exception caught at file:" << __FILE__
2650                              << ", line:" << __LINE__ << std::endl;
2651 std::exit(1);
2652 }
2653 }
2654 }
2655
2656 if (dst_on_device) {
2657 dev[i].dst_dd = (float *) dst->data;
2658 } else {
2659 const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
2660 dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
2661 }
2662 }
2663
2664 // if multiple devices are used they need to wait for the main device
2665 // here an event is recorded that signals that the main device has finished calculating the input data
2666 if (split && used_devices > 1) {
2667 ggml_sycl_set_device(ctx.device);
2668 SYCL_CHECK(CHECK_TRY_ERROR(
2669 *src0_extra->events[ctx.device][0] =
2670 ctx.stream()->ext_oneapi_submit_barrier()));
2671 }
2672
2673 const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
2674 for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
2675 const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
2676 const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
2677 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2678 if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
2679 continue;
2680 }
2681
2682 const bool src1_on_device = i == ctx.device;
2683 const bool dst_on_device = i == ctx.device;
2684 const int64_t row_diff = dev[i].row_high - dev[i].row_low;
2685
2686 ggml_sycl_set_device(i);
2687 queue_ptr stream = ctx.stream(i, is);
2688
2689 // wait for main GPU data if necessary
2690 if (split && (i != ctx.device || is != 0)) {
2691 SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
2692 {*src0_extra->events[ctx.device][0]})));
2693 }
2694
2695 for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
2696 const int64_t i03 = i0 / ne12;
2697 const int64_t i02 = i0 % ne12;
2698
2699 const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
2700
2701 // for split tensors the data begins at i0 == i0_offset_low
2702 char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
2703 float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
2704 char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset;
2705 float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
2706
2707 // the main device memory buffer can be on VRAM scratch, with space for all partial results
2708 // in that case an offset on dst_ddf_i is needed
2709 if (i == ctx.device) {
2710 dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
2711 }
2712
2713 // copy src0, src1 to device if necessary
2714 if (src1_is_contiguous) {
2715 if (i != ctx.device) {
2716 if constexpr (quantize_enabled) {
2717 char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
2718 SYCL_CHECK(
2719 CHECK_TRY_ERROR(stream
2720 ->memcpy(src1_ddq_i, src1_ddq_i_source,
2721 src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
2722 .wait()));
2723 } else {
2724 float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
2725 src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
2726
2727 SYCL_CHECK(
2728 CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
2729 src1_ncols * ne10 * sizeof(float))));
2730 }
2731 }
2732 } else {
2733 if (src1_on_device) {
2734 SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
2735 src1_col_0 + src1_ncols, stream));
2736 } else {
2737 GGML_ABORT("src1 is non-contiguous and not on device");
2738 }
2739
2740 if constexpr (quantize_enabled) {
2741 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2742 /*num_src=*/2, " : converting src1 to Q8_1");
2743 try {
2744 quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
2745 src1_padded_col_size, stream);
2746 } catch (const sycl::exception & exc) {
2747                            std::cerr << "quantize_row_q8_1_sycl error: " << exc.what()
2748                                      << ", exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2749 std::exit(1);
2750 }
2751 }
2752 }
2753
2754 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
2755 SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
2756 }
2757 if (src1->type == GGML_TYPE_F16) {
2758 src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
2759 }
2760 // do the computation
2761 SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
2762 dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
2763
2764 // copy dst to host or other device if necessary
2765 if (!dst_on_device) {
2766 void * dst_off_device = dst->data;
2767 if (split) {
2768 // src0 = weight matrix is saved as a transposed matrix for better memory layout.
2769 // dst is NOT transposed.
2770 // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
2771 // Instead they need to be copied to the correct slice in ne0 = dst row index.
2772 // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
2773 float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
2774 GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
2775 dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
2776
2777 SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2778 dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
2779 row_diff * sizeof(float), row_diff * sizeof(float),
2780 src1_ncols, dpct::device_to_device, *stream)));
2781 } else {
2782 float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
2783 GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
2784 dhf_dst_i += src1_col_0*ne0;
2785 SYCL_CHECK(CHECK_TRY_ERROR(
2786 stream->memcpy(dhf_dst_i, dst_dd_i,
2787 src1_ncols * ne0 * sizeof(float)).wait()));
2788 }
2789 }
2790
2791 // add event for the main device to wait on until other device is done
2792 if (split && (i != ctx.device || is != 0)) {
2793 SYCL_CHECK(CHECK_TRY_ERROR(
2794 *src0_extra->events[i][is] =
2795 stream->ext_oneapi_submit_barrier()));
2796 }
2797 }
2798 }
2799 }
2800
2801 // main device waits for all other devices to be finished
2802 if (split && ggml_sycl_info().device_count > 1) {
2803 int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
2804 is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
2805
2806 ggml_sycl_set_device(ctx.device);
2807 for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2808 if (dev[i].row_low == dev[i].row_high) {
2809 continue;
2810 }
2811 for (int64_t is = 0; is < is_max; ++is) {
2812 SYCL_CHECK(CHECK_TRY_ERROR(
2813 ctx.stream()->ext_oneapi_submit_barrier(
2814 {*src0_extra->events[i][is]})));
2815 }
2816 }
2817 }
2818}
2819catch (sycl::exception const &exc) {
2820 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2821 << ", line:" << __LINE__ << std::endl;
2822 std::exit(1);
2823}
2824
2825static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2826 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2827 ggml_sycl_op_repeat_back(ctx, dst);
2828}
2829
2830static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2831 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2832 ggml_sycl_op_get_rows(ctx, dst);
2833}
2834
2835static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2836 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2837 ggml_sycl_op_norm(ctx, dst);
2838}
2839
2840static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2841 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2842 ggml_sycl_op_rms_norm(ctx, dst);
2843}
2844
2845static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2846 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2847 ggml_sycl_op_rms_norm_back(ctx, dst);
2848}
2849
2850static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2851 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2852 ggml_sycl_op_l2_norm(ctx, dst);
2853}
2854
2855static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2856 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2857 ggml_sycl_op_group_norm(ctx, dst);
2858}
2859
2860static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2861 const ggml_tensor *src1,
2862 ggml_tensor *dst) try {
2863 GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
2864 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2865 GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
2866 GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
2867 GGML_ASSERT(src0->type == GGML_TYPE_F16);
2868 GGML_ASSERT(src1->type == GGML_TYPE_F32);
2869
2870 const int64_t ne00 = src0->ne[0];
2871 const int64_t ne01 = src0->ne[1];
2872 const int64_t ne02 = src0->ne[2];
2873
2874 const int64_t ne12 = src1->ne[2];
2875
2876 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2877 queue_ptr main_stream = ctx.stream();
2878
2879 void * src0_ddq = src0->data;
2880 float * src1_ddf = (float *) src1->data;
2881 float * dst_ddf = (float *) dst->data;
2882
2883 ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
2884}
2885catch (sycl::exception const &exc) {
2886 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2887 << ", line:" << __LINE__ << std::endl;
2888 std::exit(1);
2889}
2890
2891static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2892 const ggml_tensor *src1,
2893 ggml_tensor *dst) try {
2894 GGML_ASSERT(!ggml_is_transposed(src0));
2895 GGML_ASSERT(!ggml_is_transposed(src1));
2896 GGML_ASSERT(!ggml_is_permuted(src0));
2897 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2898 GGML_ASSERT(src0->type == GGML_TYPE_F16);
2899 GGML_ASSERT(src1->type == GGML_TYPE_F32);
2900 GGML_ASSERT(src1->ne[1] == 1);
2901 GGML_ASSERT(src1->ne[3] == 1);
2902
2903 const int64_t ne00 = src0->ne[0];
2904 const int64_t ne01 = src0->ne[1];
2905 const int64_t ne02 = src0->ne[2];
2906
2907 const int64_t nb01 = src0->nb[1];
2908 const int64_t nb02 = src0->nb[2];
2909
2910 const int64_t ne12 = src1->ne[2];
2911 const int64_t nb11 = src1->nb[1];
2912
2913 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2914 queue_ptr main_stream = ctx.stream();
2915
2916 void * src0_ddq = src0->data;
2917 float * src1_ddf = (float *) src1->data;
2918 float * dst_ddf = (float *) dst->data;
2919
2920 const int64_t row_stride_x = nb01 / sizeof(sycl::half);
2921 const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
2922 const int64_t channel_stride_y = nb11 / sizeof(float);
2923
2924    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, channel_stride_y, main_stream);
2925}
2926catch (sycl::exception const &exc) {
2927 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2928 << ", line:" << __LINE__ << std::endl;
2929 std::exit(1);
2930}
2931
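// k_compute_batched_ptrs: builds the per-batch pointer tables consumed by gemm_batch. For
// every (i12, i13) output batch it resolves the broadcast source batch (i02 = i12 / r2,
// i03 = i13 / r3) and stores byte offsets into src0, src1 and dst.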
2932static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
2933 const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
2934 size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
2935 int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
2936 const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
2937 const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
2938
2939 if (i13 >= ne13 || i12 >= ne12) {
2940 return;
2941 }
2942
2943 const int64_t i03 = i13 / r3;
2944 const int64_t i02 = i12 / r2;
2945
2946 const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
2947 const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
2948 uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
2949
2950 ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
2951 ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
2952 ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
2953}
2954
2955static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
2956 const ggml_tensor * src1, ggml_tensor * dst) try {
2957 GGML_ASSERT(!ggml_is_transposed(src0));
2958 GGML_ASSERT(!ggml_is_transposed(src1));
2959 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2960 GGML_ASSERT(src0->type == GGML_TYPE_F16);
2961 GGML_ASSERT(dst->type == GGML_TYPE_F32);
2962
2963 GGML_TENSOR_BINARY_OP_LOCALS
2964
2965 // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
2966 // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
2967 GGML_ASSERT(ggml_is_contiguous(dst));
2968
2969 SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2970 queue_ptr queue = ctx.stream();
2971
2972 dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
2973
2974 const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
2975 float * dst_ddf = static_cast<float *>(dst->data);
2976
2977 const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
2978 const size_t type_size_src0 = ggml_type_size(src0->type);
2979 const size_t type_size_src1 = ggml_type_size(src1->type);
2980
2981 bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
2982 bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
2983
2984 // SRC1 strides
2985 int64_t s11 = nb11 / type_size_src1;
2986 int64_t s12 = nb12 / type_size_src1;
2987 int64_t s13 = nb13 / type_size_src1;
2988 ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
2989
2990 // convert src1 to fp16
2991 if (src1->type != GGML_TYPE_F16) {
2992 scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
2993 " : converting src1 to fp16");
2994
2995 // iterate tensor dims and find the slowest moving dim and stride
2996        int last_dim = 0;
2997        int last_str = 0;
2998        size_t largest_str = 0;
2999        for (int i = 0; i < 4; i++) {
3000            // the slowest-moving dimension always has the largest stride
3001            if (src1->nb[i] == largest_str) {
3002                if (src1->ne[last_dim] == 1) {
3003                    last_str = i;
3004                    last_dim = i;
3005                }
3006            }
3007            if (src1->nb[i] > largest_str) {
3008                largest_str = src1->nb[i];
3009                last_str = i;
3010                last_dim = i;
3011            }
3012
3013        }
3014#if GGML_SYCL_DNNL
3015 // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
3016 const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
3017 src1_f16_alloc.alloc(ne_src1);
3018 const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
3019 GGML_ASSERT(to_fp16_sycl != nullptr);
3020 to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
3021# else
3022 const int64_t ne_src1 = ggml_nelements(src1);
3023 src1_f16_alloc.alloc(ne_src1);
3024 const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
3025 GGML_ASSERT(to_fp16_nc_sycl != nullptr);
3026 to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
3027#endif
3028
3029 src1_f16 = src1_f16_alloc.get();
3030 s11 = ne10;
3031 s12 = ne11 * s11;
3032 s13 = ne12 * s12;
3033
3034 is_src1_cont_2 = true;
3035 }
3036
3037 ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
3038
3039 dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
3040 dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
3041
3042 // dst strides
3043 size_t nbd2 = dst->nb[2];
3044 size_t nbd3 = dst->nb[3];
3045
3046 const float alpha_f32 = 1.0f;
3047 const float beta_f32 = 0.0f;
3048
3049 const void * alpha = &alpha_f32;
3050 const void * beta = &beta_f32;
3051
3052 GGML_ASSERT(ne12 % ne02 == 0);
3053 GGML_ASSERT(ne13 % ne03 == 0);
3054 GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
3055 GGML_ASSERT(ne10 == ne00);
3056
3057 // broadcast factors
3058 const int64_t r2 = ne12 / ne02;
3059 const int64_t r3 = ne13 / ne03;
3060
3061#if GGML_SYCL_DNNL
3062 if (!g_ggml_sycl_disable_dnn) {
3063 int64_t str_a0 = nb00 / type_size_src0;
3064 int64_t str_a1 = nb01 / type_size_src0;
3065 int64_t str_a2 = nb02 / type_size_src0;
3066
3067 int64_t str_b0 = nb10 / type_size_src1;
3068 int64_t str_b1 = nb11 / type_size_src1;
3069 int64_t str_b2 = nb12 / type_size_src1;
3070
3071 auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
3072 const sycl::half *src1, float *dst,
3073 int64_t a0, int64_t a1, int64_t batcha,
3074 int64_t /*b0*/, int64_t b1, int64_t batchb,
3075 int64_t sa0, int64_t sa1, int64_t sa2,
3076 int64_t sb0, int64_t sb1, int64_t sb2,
3077 int64_t sd2) {
3078            bool supported_broadcast = batcha == batchb || batcha == 1 || batchb == 1;
3081 if (supported_broadcast) {
3082 DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
3083 DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
3084 DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
3085 DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
3086 } else {
3087 // iterate over batches from smaller set of matrices (matrix 0)
3088 int64_t batches0 = batcha;
3089 int64_t batches1 = batchb;
3090
3091 if (batches0 > batches1) {
3092 int64_t num_mul_mats = batches1;
3093 int64_t sub_batch = batches0 / num_mul_mats;
3094 // src0 is batched and bigger, shift and multiply with src1
3095 for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
3096 const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
3097 const sycl::half *src1_shifted = src1 + (sb2 * i0);
3098 float *dst_shifted = dst + (sd2 * i0 * sub_batch);
3099 DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
3100 DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
3101 src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
3102 sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
3103 queue, sub_batch, 1);
3104 }
3105 } else {
3106 int64_t num_mul_mats = batches0;
3107 int64_t sub_batch = batches1 / num_mul_mats;
3108 // src1 is batched and bigger, shift and multiply with src0
3109 for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
3110 const sycl::half *src0_shifted = src0 + (sa2 * i1);
3111 const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
3112 float *dst_shifted = dst + (sd2 * i1 * sub_batch);
3113 DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
3114 DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
3115 src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
3116 sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
3117 queue, 1, sub_batch);
3118 }
3119 }
3120 }
3121 };
3122
3123 const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
3124 const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
3125 const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
3126 const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
3127 if (cont_batches_dim2_a && cont_batches_dim2_b) {
3128 // A batch is considered contiguous if the dimension 2 is not strided
3129 int64_t batches0 = ne02 * ne03;
3130 int64_t batches1 = ne12 * ne13;
3131 launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
3132 ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
3133 str_b2, nb2 / sizeof(float));
3134 } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
3135 // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
3136 int64_t batches0 = ne02 * ne03;
3137 int64_t batches1 = ne12 * ne13;
3138 int64_t str_a3 = nb03 / type_size_src0;
3139 int64_t str_b3 = nb13 / type_size_src1;
3140 launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
3141 ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
3142 str_b3, nb2 / sizeof(float));
3143 } else {
3144 for (int64_t b_a = 0; b_a < ne03; b_a++) {
3145 const sycl::half *src0_f16_shifted
3146 = src0_f16 + (nb03 * b_a / type_size_src0);
3147 const sycl::half *src1_f16_shifted
3148 = src1_f16 + (nb13 * b_a / type_size_src1);
3149 float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
3150 int64_t batches0 = ne02;
3151 int64_t batches1 = ne12;
3152 launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
3153 ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
3154 str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
3155 }
3156 }
3157
3158 }
3159 else
3160#endif
3161 {
3162 if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
3163 // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
3164 const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
3165 const int64_t smb = ne12 == 1 ? s13 : s12;
3166
3167 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
3168 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
3169 oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3170 src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
3171 src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
3172 mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
3173 } else {
3174 const int ne23 = ne12 * ne13;
3175
3176 ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
3177 ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
3178 ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
3179
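            // Build arrays of per-matrix device pointers (k_compute_batched_ptrs) so the
            // pointer-array variant of gemm_batch can handle broadcasting (r2/r3) and
            // arbitrary batch strides.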
3180 sycl::range<3> block_dims(1, ne12, ne13);
3181 queue->submit([&](sycl::handler & cgh) {
3182 const void ** ptrs_src_get = ptrs_src.get();
3183 void ** ptrs_dst_get = ptrs_dst.get();
3184 size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
3185 size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
3186 cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3187 k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
3188 nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
3189 });
3190 });
3191
3192 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3193 *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3194 (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3195 (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
3196 (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
3197 }
3198 }
3199} catch (const sycl::exception & exc) {
3200 std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3201 std::exit(1);
3202}
3203
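// Mul-mat kernel families considered when deciding whether to reorder quantized weights
// (see opt_for_reorder()).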
3204enum class mul_mat_algo {
3205 DMMV = 0,
3206 MMVQ = 1,
3207 MUL_MAT_SYCL = 2,
3208};
3209
3210inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
3211 // TODO: accuracy issues in MMQ
3212 GGML_UNUSED(type);
3213 return false;
3214}
3215
3216inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
3217 switch (type) {
3218 case GGML_TYPE_Q4_0:
3219 return true;
3220 case GGML_TYPE_Q4_K:
3221 case GGML_TYPE_Q6_K:
3222 return !g_ggml_sycl_prioritize_dmmv;
3223 default:
3224 return false;
3225 }
3226}
3227
3228inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
3229 switch (type) {
3230 case GGML_TYPE_Q4_0:
3231 return true;
3232 default:
3233 return false;
3234 }
3235}
3236
3237inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
3238 switch (type) {
3239 case GGML_TYPE_Q4_0:
3240 case GGML_TYPE_Q4_K:
3241 case GGML_TYPE_Q6_K:
3242 return true;
3243 default:
3244 return false;
3245 }
3246}
3247
3248static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3249 switch (type) {
3250 case GGML_TYPE_Q4_0:
3251 case GGML_TYPE_Q4_1:
3252 case GGML_TYPE_Q5_0:
3253 case GGML_TYPE_Q5_1:
3254 case GGML_TYPE_Q8_0:
3255 case GGML_TYPE_Q2_K:
3256 case GGML_TYPE_Q3_K:
3257 case GGML_TYPE_Q4_K:
3258 case GGML_TYPE_Q5_K:
3259 case GGML_TYPE_Q6_K:
3260 case GGML_TYPE_F16:
3261 return true;
3262 default:
3263 return false;
3264 }
3265}
3266
3267// Helper functions to unify device memory allocation for both async and sync paths
3268static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3269 bool use_async = g_ggml_sycl_use_async_mem_op;
3270#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3271 if (use_async) {
3272 return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
3273 }
3274#else
3275 // If async allocation extension is not available, use_async should always be false.
3276 GGML_ASSERT(!use_async);
3277#endif
3278 return sycl::malloc(size, *stream, sycl::usm::alloc::device);
3279}
3280
3281static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3282 bool use_async = g_ggml_sycl_use_async_mem_op;
3283#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3284 if (use_async) {
3285 syclex::async_free(*stream, ptr);
3286 return;
3287 }
3288#else
3289 // If async allocation extension is not available, use_async should always be false.
3290 GGML_ASSERT(!use_async);
3291#endif
3292 sycl::free(ptr, *stream);
3293}
3294
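// Reorders Q4_0 data in place from an array-of-blocks layout to a structure-of-arrays
// layout (all quant nibbles first, then all scales), staging the original blocks through
// a temporary device buffer.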
3295static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
3296 dpct::queue_ptr stream) {
3297 uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3298
3299 sycl::event copy_event;
3300 SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3301 if (!g_ggml_sycl_use_async_mem_op) {
3302 copy_event.wait();
3303 }
3304
3305 GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3306 GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
3307 int offset_blks = offset / sizeof(block_q4_0);
3308 auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
3309 auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
3310
3311 auto reorder_event = stream->parallel_for(
3312 size / sizeof(block_q4_0),
3313 [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
3314 const block_q4_0* x = (const block_q4_0*)tmp_buf;
3315 const int ib = i;
3316
            for (int j = 0; j < QK4_0 / 2; ++j) {
                *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
            }
3321 *(d_ptr + ib) = x[ib].d;
3322 });
3323 if (!g_ggml_sycl_use_async_mem_op) {
3324 reorder_event.wait_and_throw();
3325 }
3326 sycl_ext_free(stream, tmp_buf);
3327}
3328
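// Same in-place reordering for Q4_K blocks: all quants first, then all scales, then all
// dm (scale/min) values.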
3329static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3330 GGML_ASSERT(size % sizeof(block_q4_K) == 0);
3331 GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
3332
3333 const int nblocks = size / sizeof(block_q4_K);
3334
3335 uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3336
3337 sycl::event copy_event;
3338 SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3339 if (!g_ggml_sycl_use_async_mem_op) {
3340 copy_event.wait();
3341 }
3342
3343 auto * qs_ptr = data_device;
3344 auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3345 auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3346
3347 auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3348 const block_q4_K * x = (const block_q4_K *) tmp_buf;
3349 const int ib = i;
3350
3351 for (int j = 0; j < QK_K / 2; ++j) {
3352 qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
3353 }
3354
3355 for (int j = 0; j < K_SCALE_SIZE; ++j) {
3356 scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
3357 }
3358
3359 dm_ptr[ib] = x[ib].dm;
3360 });
3361 if (!g_ggml_sycl_use_async_mem_op) {
3362 reorder_event.wait_and_throw();
3363 }
3364 sycl_ext_free(stream, tmp_buf);
3365}
3366
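// Same in-place reordering for Q6_K blocks: all ql values first, then qh, then scales,
// then the per-block d values.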
3367static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3368 GGML_ASSERT(size % sizeof(block_q6_K) == 0);
3369 GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
3370
3371 const int nblocks = size / sizeof(block_q6_K);
3372
3373 uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3374
3375 sycl::event copy_event;
3376 SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3377 if (!g_ggml_sycl_use_async_mem_op) {
3378 copy_event.wait();
3379 }
3380
3381 auto * ql_ptr = data_device;
3382 auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
3383 auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3384 sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3385
3386 auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3387 const block_q6_K * x = (const block_q6_K *) tmp_buf;
3388 const int ib = i;
3389
3390 const uint8_t * ql = x[ib].ql;
3391 const uint8_t * qh = x[ib].qh;
3392 uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3393 uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3394 uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3395
3396 for (int j = 0; j < QK_K / 2; ++j) {
3397 base_ql_ptr[j] = ql[j];
3398 }
3399 for (int j = 0; j < QK_K / 4; ++j) {
3400 base_qh_ptr[j] = qh[j];
3401 }
3402
3403 for (int j = 0; j < QK_K / 16; ++j) {
3404 base_scales_ptr[j] = x[ib].scales[j];
3405 }
3406
3407 dm_ptr[ib] = x[ib].d;
3408 });
3409 if (!g_ggml_sycl_use_async_mem_op) {
3410 reorder_event.wait_and_throw();
3411 }
3412 sycl_ext_free(stream, tmp_buf);
3413}
3414
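// Dispatches the in-place weight reordering according to the quantization type of src0;
// aborts on unsupported types.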
3415static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
3416 uint8_t * data_device = (uint8_t *) src0->data;
3417 size_t ncols = src0->ne[0];
3418 size_t nrows = src0->ne[1];
3419 size_t size = ggml_nbytes(src0);
3420
3421 switch (src0->type) {
3422 case GGML_TYPE_Q4_0:
3423 reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
3424 break;
3425 case GGML_TYPE_Q4_K:
3426 reorder_qw_q4_k(data_device, size, 0, stream);
3427 break;
3428 case GGML_TYPE_Q6_K:
3429 reorder_qw_q6_k(data_device, size, 0, stream);
3430 break;
3431 default:
3432 GGML_ABORT("reorder_qw() called with unsupported type");
3433 break;
3434 }
3435}
3436
3437static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
    return !g_ggml_sycl_disable_optimize &&    // optimization allowed (controlled by the GGML_SYCL_DISABLE_OPT env var)
        ctx.opt_feature.reorder &&             // only on devices where the reorder gives good performance
        dst->op == GGML_OP_MUL_MAT &&          // currently limited to some supported MUL_MAT cases (e.g. Q4_0); more cases are TODO
        dst->src[1]->ne[1] == 1 && dst->src[1]->ne[2] == 1 && dst->src[1]->ne[3] == 1;
3442}
3443
3444static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
3445 ggml_tensor * dst, mul_mat_algo mm_algorithm) {
3446 if (!should_reorder_tensor(*ctx, dst)) {
3447 return;
3448 }
3449
3450 ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3451 if (!extra || extra->optimized_feature.reorder) {
3452 return; // Skip permutations and already reordered tensors
3453 }
3454
3455 switch (mm_algorithm) {
3456 case mul_mat_algo::DMMV:
3457 if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
3458 return;
3459 }
3460 break;
3461 case mul_mat_algo::MMVQ:
3462 if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
3463 return;
3464 }
3465 break;
3466 case mul_mat_algo::MUL_MAT_SYCL:
3467 if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
3468 return;
3469 }
3470 break;
3471 }
3472
3473 reorder_qw(src0, ctx->stream());
    extra->optimized_feature.reorder = true; // mark as reordered so later dequantization uses the new layout and the tensor is not reordered again
3475}
3476
3477
3478static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3479 return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3480 src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3481}
3482
3483static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3484 return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3485 src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3486}
3487
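// Top-level MUL_MAT dispatcher: selects between the batched SYCL/oneDNN path, the
// dequantize_mul_mat_vec and mul_mat_vec_q kernels, and the generic path, based on
// tensor types, shapes and the reorder optimization.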
3488static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3489 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3490 const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
3491 int64_t min_compute_capability = INT_MAX;
3492
3493 if (split) {
3494 ggml_backend_sycl_split_buffer_type_context * buft_ctx =
3495 (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
3496 auto & tensor_split = buft_ctx->tensor_split;
3497 for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
3498 // skip devices that are not going to do any work:
3499 if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
3500 continue;
3501 }
3502
3503 if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
3504 min_compute_capability = ggml_sycl_info().devices[id].cc;
3505 }
3506 }
3507 } else {
3508 min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
3509 }
3510
3511 // check data types and tensor shapes for custom matrix multiplication kernels:
3512 bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
3513
3514 bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
3515
3516 bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
3517 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
3518
3519
3520 // mmvq and mmq need the __dp4a instruction which is available for gen12+
3521 // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3522 use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
3523#ifdef SYCL_USE_XMX
3524 use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3525#endif // SYCL_USE_XMX
3526
    // The reorder optimization complicates dispatch: when it is enabled, MMVQ takes
    // precedence over DMMV, so the current if-else chain has to disable DMMV when
    // both paths would otherwise apply.
3530 if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
3531 ggml_sycl_supports_reorder_mmvq(src0->type)))) {
3532 use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
3533 }
3534
3535 if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
3536 // TODO: Refactor and cleanup of mul mat dispatching.
3537 if (src0->ne[3] == 1 && src1->ne[3] == 1) {
3538 // KQ single-batch
3539 // mmv p021 was specific for these dimensions
3540 ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
3541 } else {
3542 // The kernel from the if path is faster for that specific case, but does not support all mul mats.
3543 ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3544 }
3545 } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
3546 // KQV single-batch
3547 ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
3548 } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
3549 // KQ + KQV multi-batch
3550 ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3551 } else if (use_dequantize_mul_mat_vec) {
3552 opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3553 ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
3554 } else if (use_mul_mat_vec_q) {
3555 opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3556 ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3557 if (extra && extra->optimized_feature.reorder) {
3558 ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3559 } else {
3560 ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3561 }
3562 } else if (use_mul_mat_q) {
3563 ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
3564 } else {
3565 ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
3566 }
3567}
3568
3569
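// Maps a row in the compacted src1/dst staging buffers back to its original
// (i1, i2) position in the non-contiguous dst tensor.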
3570struct mmid_row_mapping {
3571 int32_t i1;
3572 int32_t i2;
3573};
3574
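// Device kernel: for the expert index i02, gathers the src1 rows routed to that expert
// into a contiguous buffer and records the row mapping; each work-group handles one
// (token, expert-slot) pair and atomically claims a row in the contiguous buffer.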
3575__dpct_inline__ static void k_copy_src1_to_contiguous(
3576 const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3577 int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3578 const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3579 int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3580 const sycl::nd_item<3> &item_ct1, int &src1_row) {
3581 int32_t iid1 = item_ct1.get_group(2);
3582 int32_t id = item_ct1.get_group(1);
3583
3584 const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3585
3586 if (row_id_i != i02) {
3587 return;
3588 }
3589
3590 const int64_t i11 = id % ne11;
3591 const int64_t i12 = iid1;
3592
3593 if (item_ct1.get_local_id(2) == 0) {
3594 src1_row =
3595 dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3596 cur_src1_row, 1);
3597 row_mapping[src1_row] = {id, iid1};
3598 }
3599 /*
3600 DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3601 sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3602 performance if there is no access to global memory.
3603 */
3604 item_ct1.barrier();
3605
3606 const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
3607 float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
3608
3609#pragma unroll
3610 for (int i = item_ct1.get_local_id(2); i < ne10;
3611 i += item_ct1.get_local_range(2)) {
3612 src1_row_contiguous[i] = src1_row_original[i];
3613 }
3614}
3615
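// Device kernel: scatters rows from the contiguous dst buffer back to their original
// positions in dst using the recorded row mapping.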
3616__dpct_inline__ static void k_copy_dst_from_contiguous(
3617 char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
3618 const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
3619 size_t nb2, const sycl::nd_item<3> &item_ct1) {
3620 int32_t i = item_ct1.get_group(2);
3621
3622 const int32_t i1 = row_mapping[i].i1;
3623 const int32_t i2 = row_mapping[i].i2;
3624
3625 const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
3626 float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
3627
3628#pragma unroll
3629 for (int j = item_ct1.get_local_id(2); j < ne0;
3630 j += item_ct1.get_local_range(2)) {
3631 dst_row_original[j] = dst_row_contiguous[j];
3632 }
3633}
3634
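// MUL_MAT_ID (mixture-of-experts matmul): for ne12 == 1 it loops over the ids on the host
// and performs one mul_mat per selected expert row; otherwise it gathers the rows routed
// to each expert into contiguous buffers, runs one mul_mat per expert and scatters the
// results back.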
3635static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3636 ggml_tensor *dst) try {
3637 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
3638 const ggml_tensor *src0 = dst->src[0];
3639 const ggml_tensor *src1 = dst->src[1];
3640 GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
3641
3642 const ggml_tensor *ids = dst->src[2];
3643 GGML_TENSOR_BINARY_OP_LOCALS
3644
3645 const queue_ptr stream = ctx.stream();
3646
3647 const int64_t n_as = ne02;
3648 const int64_t n_ids = ids->ne[0];
3649
3650 std::vector<char> ids_host(ggml_nbytes(ids));
3651 const char * ids_dev = (const char *) ids->data;
3652
3653 SYCL_CHECK(CHECK_TRY_ERROR(
3654 stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
3655 SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
3656
3657 ggml_tensor src0_row = *src0;
3658 ggml_tensor src1_row = *src1;
3659 ggml_tensor dst_row = *dst;
3660
3661 char *src0_original = (char *)src0->data;
3662 char *src1_original = (char *)src1->data;
3663 char *dst_original = (char *)dst->data;
3664
3665 src0_row.ne[2] = 1;
3666 src0_row.ne[3] = 1;
3667 src0_row.nb[3] = nb02;
3668
3669 src1_row.ne[1] = 1;
3670 src1_row.ne[2] = 1;
3671 src1_row.ne[3] = 1;
3672 src1_row.nb[2] = nb11;
3673 src1_row.nb[3] = nb11;
3674
3675 dst_row.ne[1] = 1;
3676 dst_row.ne[2] = 1;
3677 dst_row.ne[3] = 1;
3678 dst_row.nb[2] = nb1;
3679 dst_row.nb[3] = nb1;
3680 if (ne12 == 1) {
3681 for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3682 for (int64_t id = 0; id < n_ids; id++) {
3683 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
3684 GGML_ASSERT(i02 >= 0 && i02 < n_as);
3685
3686 const int64_t i11 = id % ne11;
3687 const int64_t i12 = iid1;
3688
3689 const int64_t i1 = id;
3690 const int64_t i2 = i12;
3691
3692 src0_row.data = src0_original + i02*nb02;
3693 src1_row.data = src1_original + i11*nb11 + i12*nb12;
3694 dst_row.data = dst_original + i1*nb1 + i2*nb2;
3695
3696 ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3697 }
3698 }
3699 } else {
3700 ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
3701 ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
3702
3703 src1_row.data = src1_contiguous.get();
3704 dst_row.data = dst_contiguous.get();
3705
3706 for (int64_t i02 = 0; i02 < n_as; i02++) {
3707 int64_t num_src1_rows = 0;
3708 for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3709 for (int64_t id = 0; id < n_ids; id++) {
3710 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
3711
3712 GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
3713
3714 if (row_id_i != i02) {
3715 continue;
3716 }
3717
3718 num_src1_rows++;
3719 }
3720 }
3721
3722 if (num_src1_rows == 0) {
3723 continue;
3724 }
3725
3726
3727 ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
3728 ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
3729 SYCL_CHECK(CHECK_TRY_ERROR(
3730 stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
3731
3732 const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
3733 assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
3734
3735 {
3736 sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
3737 sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
3738 stream->submit([&](sycl::handler &cgh) {
3739 sycl::local_accessor<int, 0> src1_row_acc(cgh);
3740
3741 char *__restrict src1_contiguous_get =
3742 src1_contiguous.get();
3743 int *__restrict dev_cur_src1_row_get =
3744 dev_cur_src1_row.get();
3745 mmid_row_mapping *__restrict dev_row_mapping_get =
3746 dev_row_mapping.get();
3747 size_t ids_nb_ct6 = ids->nb[1];
3748 size_t ids_nb_ct7 = ids->nb[0];
3749
3750 cgh.parallel_for(
3751 sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3752 [=](sycl::nd_item<3> item_ct1) {
3753 k_copy_src1_to_contiguous(
3754 src1_original, src1_contiguous_get,
3755 dev_cur_src1_row_get,
3756 dev_row_mapping_get, ids_dev, i02,
3757 ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
3758 item_ct1, src1_row_acc);
3759 });
3760 });
3761 }
3762
3763 src0_row.data = src0_original + i02*nb02;
3764
3765 GGML_ASSERT(nb11 == sizeof(float)*ne10);
3766 GGML_ASSERT(nb1 == sizeof(float)*ne0);
3767 src1_row.ne[1] = num_src1_rows;
3768
3769 src1_row.nb[1] = nb11;
3770 src1_row.nb[2] = num_src1_rows*nb11;
3771 src1_row.nb[3] = num_src1_rows*nb11;
3772
3773 dst_row.ne[1] = num_src1_rows;
3774 dst_row.nb[1] = nb1;
3775 dst_row.nb[2] = num_src1_rows*nb1;
3776 dst_row.nb[3] = num_src1_rows*nb1;
3777
3778 ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3779
3780 {
3781 sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
3782 sycl::range<3> grid_dims(1, 1, num_src1_rows);
3783 stream->submit([&](sycl::handler &cgh) {
3784 const char *__restrict dst_contiguous_get =
3785 dst_contiguous.get();
3786 const mmid_row_mapping *__restrict dev_row_mapping_get =
3787 dev_row_mapping.get();
3788
3789 cgh.parallel_for(
3790 sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3791 [=](sycl::nd_item<3> item_ct1) {
3792 k_copy_dst_from_contiguous(dst_original,
3793 dst_contiguous_get,
3794 dev_row_mapping_get,
3795 ne0, nb1, nb2, item_ct1);
3796 });
3797 });
3798 }
3799 }
3800 }
3801}
3802catch (sycl::exception const &exc) {
3803 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3804 << ", line:" << __LINE__ << std::endl;
3805 std::exit(1);
3806}
3807
3808static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3809 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3810 ggml_sycl_op_scale(ctx, dst);
3811}
3812
3813static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3814 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3815 ggml_sycl_op_diag_mask_inf(ctx, dst);
3816}
3817
3818static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3819 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3820 ggml_sycl_op_pool2d(ctx, dst);
3821}
3822
3823static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3824 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3825 ggml_sycl_op_im2col(ctx, dst);
3826}
3827
3828static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3829 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3830 GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3831 ggml_sycl_op_sum(ctx, dst);
3832}
3833
3834static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3835 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3836 GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3837 ggml_sycl_op_sum_rows(ctx, dst);
3838}
3839
3840static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3841 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3842 GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3843 ggml_sycl_op_mean(ctx, dst);
3844}
3845
3846static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3847 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3848 GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3849 ggml_sycl_op_argsort(ctx, dst);
3850}
3851
3852static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3853 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3854 GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3855 ggml_sycl_op_argmax(ctx, dst);
3856}
3857
3858
3859static void ggml_sycl_set_main_device(const int main_device) try {
3860 if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
3861 return;
3862 }
3863 check_allow_gpu_index(main_device);
3864 dpct::select_device(main_device);
3865
3866 if (g_ggml_sycl_debug) {
3867 dpct::device_info prop;
3868 SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
3869 prop, dpct::dev_mgr::instance().get_device(main_device))));
3870 GGML_LOG_INFO("Using device %d (%s) as main device\n",
3871 main_device, prop.get_name());
3872 }
3873}
3874catch (sycl::exception const &exc) {
3875 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3876 << ", line:" << __LINE__ << std::endl;
3877 std::exit(1);
3878}
3879
3880static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
3881 if (!g_sycl_loaded) return false;
3882
3883 if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
3884 ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
3885 }
3886
3887 switch (dst->op) {
3888 case GGML_OP_ARGMAX:
3889 ggml_sycl_argmax(ctx, dst);
3890 break;
3891 case GGML_OP_CONV_TRANSPOSE_1D:
3892 ggml_sycl_op_conv_transpose_1d(ctx, dst);
3893 break;
3894 case GGML_OP_REPEAT:
3895 ggml_sycl_repeat(ctx, dst);
3896 break;
3897 case GGML_OP_REPEAT_BACK:
3898 ggml_sycl_repeat_back(ctx, dst);
3899 break;
3900 case GGML_OP_GET_ROWS:
3901 ggml_sycl_get_rows(ctx, dst);
3902 break;
3903 case GGML_OP_SET:
3904 ggml_sycl_op_set(ctx, dst);
3905 break;
3906 case GGML_OP_SET_ROWS:
3907 ggml_sycl_op_set_rows(ctx, dst);
3908 break;
3909 case GGML_OP_DUP:
3910 ggml_sycl_dup(ctx, dst);
3911 break;
3912 case GGML_OP_ADD:
3913 case GGML_OP_ADD1: // TODO: more efficient implementation
3914 ggml_sycl_add(ctx, dst);
3915 break;
3916 case GGML_OP_ADD_ID:
3917 ggml_sycl_add_id(ctx, dst);
3918 break;
3919 case GGML_OP_SUB:
3920 ggml_sycl_sub(ctx, dst);
3921 break;
3922 case GGML_OP_COUNT_EQUAL:
3923 ggml_sycl_count_equal(ctx, dst);
3924 break;
3925 case GGML_OP_ACC:
3926 ggml_sycl_acc(ctx, dst);
3927 break;
3928 case GGML_OP_MUL:
3929 ggml_sycl_mul(ctx, dst);
3930 break;
3931 case GGML_OP_LOG:
3932 ggml_sycl_log(ctx, dst);
3933 break;
3934 case GGML_OP_DIV:
3935 ggml_sycl_div(ctx, dst);
3936 break;
3937 case GGML_OP_UNARY:
3938 switch (ggml_get_unary_op(dst)) {
3939 case GGML_UNARY_OP_NEG:
3940 ggml_sycl_neg(ctx, dst);
3941 break;
3942 case GGML_UNARY_OP_STEP:
3943 ggml_sycl_step(ctx, dst);
3944 break;
3945 case GGML_UNARY_OP_GELU:
3946 ggml_sycl_gelu(ctx, dst);
3947 break;
3948 case GGML_UNARY_OP_SILU:
3949 ggml_sycl_silu(ctx, dst);
3950 break;
3951 case GGML_UNARY_OP_GELU_QUICK:
3952 ggml_sycl_gelu_quick(ctx, dst);
3953 break;
3954 case GGML_UNARY_OP_GELU_ERF:
3955 ggml_sycl_gelu_erf(ctx, dst);
3956 break;
3957 case GGML_UNARY_OP_TANH:
3958 ggml_sycl_tanh(ctx, dst);
3959 break;
3960 case GGML_UNARY_OP_RELU:
3961 ggml_sycl_relu(ctx, dst);
3962 break;
3963 case GGML_UNARY_OP_SIGMOID:
3964 ggml_sycl_sigmoid(ctx, dst);
3965 break;
3966 case GGML_UNARY_OP_HARDSIGMOID:
3967 ggml_sycl_hardsigmoid(ctx, dst);
3968 break;
3969 case GGML_UNARY_OP_HARDSWISH:
3970 ggml_sycl_hardswish(ctx, dst);
3971 break;
3972 case GGML_UNARY_OP_EXP:
3973 ggml_sycl_exp(ctx, dst);
3974 break;
3975 case GGML_UNARY_OP_SOFTPLUS:
3976 ggml_sycl_softplus(ctx, dst);
3977 break;
3978 case GGML_UNARY_OP_SGN:
3979 ggml_sycl_sgn(ctx, dst);
3980 break;
3981 case GGML_UNARY_OP_ABS:
3982 ggml_sycl_abs(ctx, dst);
3983 break;
3984 case GGML_UNARY_OP_ELU:
3985 ggml_sycl_elu(ctx, dst);
3986 break;
3987 case GGML_UNARY_OP_FLOOR:
3988 ggml_sycl_floor(ctx, dst);
3989 break;
3990 case GGML_UNARY_OP_CEIL:
3991 ggml_sycl_ceil(ctx, dst);
3992 break;
3993 case GGML_UNARY_OP_ROUND:
3994 ggml_sycl_round(ctx, dst);
3995 break;
3996 case GGML_UNARY_OP_TRUNC:
3997 ggml_sycl_trunc(ctx, dst);
3998 break;
3999 default:
4000 return false;
4001 }
4002 break;
4003 case GGML_OP_GLU:
4004 switch (ggml_get_glu_op(dst)) {
4005 case GGML_GLU_OP_REGLU:
4006 ggml_sycl_reglu(ctx, dst);
4007 break;
4008 case GGML_GLU_OP_GEGLU:
4009 ggml_sycl_geglu(ctx, dst);
4010 break;
4011 case GGML_GLU_OP_SWIGLU:
4012 ggml_sycl_swiglu(ctx, dst);
4013 break;
4014 case GGML_GLU_OP_SWIGLU_OAI:
4015 ggml_sycl_swiglu_oai(ctx, dst);
4016 break;
4017 case GGML_GLU_OP_GEGLU_ERF:
4018 ggml_sycl_geglu_erf(ctx, dst);
4019 break;
4020 case GGML_GLU_OP_GEGLU_QUICK:
4021 ggml_sycl_geglu_quick(ctx, dst);
4022 break;
4023 default:
4024 return false;
4025 }
4026 break;
4027 case GGML_OP_NORM:
4028 ggml_sycl_norm(ctx, dst);
4029 break;
4030 case GGML_OP_GROUP_NORM:
4031 ggml_sycl_group_norm(ctx, dst);
4032 break;
4033 case GGML_OP_CONCAT:
4034 ggml_sycl_op_concat(ctx, dst);
4035 break;
4036 case GGML_OP_PAD_REFLECT_1D:
4037 ggml_sycl_op_pad_reflect_1d(ctx,dst);
4038 break;
4039 case GGML_OP_UPSCALE:
4040 ggml_sycl_upscale(ctx, dst);
4041 break;
4042 case GGML_OP_PAD:
4043 ggml_sycl_pad(ctx, dst);
4044 break;
4045 case GGML_OP_LEAKY_RELU:
4046 ggml_sycl_leaky_relu(ctx, dst);
4047 break;
4048 case GGML_OP_RMS_NORM_BACK:
4049 ggml_sycl_rms_norm_back(ctx, dst);
4050 break;
4051 case GGML_OP_RMS_NORM:
4052 ggml_sycl_rms_norm(ctx, dst);
4053 break;
4054 case GGML_OP_L2_NORM:
4055 ggml_sycl_l2_norm(ctx, dst);
4056 break;
4057 case GGML_OP_MUL_MAT:
4058 if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
4059 return false;
4060 }
4061 /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
4062 ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
4063 break;
4064 case GGML_OP_MUL_MAT_ID:
4065 if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
4066 return false;
4067 }
4068 ggml_sycl_mul_mat_id(ctx, dst);
4069 break;
4070 case GGML_OP_OUT_PROD:
4071 ggml_sycl_op_out_prod(ctx, dst);
4072 break;
4073 case GGML_OP_SCALE:
4074 ggml_sycl_scale(ctx, dst);
4075 break;
4076 case GGML_OP_SQR:
4077 ggml_sycl_sqr(ctx, dst);
4078 break;
4079 case GGML_OP_SQRT:
4080 ggml_sycl_sqrt(ctx, dst);
4081 break;
4082 case GGML_OP_SIN:
4083 ggml_sycl_sin(ctx, dst);
4084 break;
4085 case GGML_OP_COS:
4086 ggml_sycl_cos(ctx, dst);
4087 break;
4088 case GGML_OP_CLAMP:
4089 ggml_sycl_clamp(ctx, dst);
4090 break;
4091 case GGML_OP_CPY:
4092 ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
4093 break;
4094 case GGML_OP_CONT:
4095 ggml_sycl_dup(ctx, dst);
4096 break;
4097 case GGML_OP_NONE:
4098 case GGML_OP_RESHAPE:
4099 case GGML_OP_VIEW:
4100 case GGML_OP_PERMUTE:
4101 case GGML_OP_TRANSPOSE:
4102 GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
4103 break;
4104 case GGML_OP_TRI:
4105 ggml_sycl_op_tri(ctx, dst);
4106 break;
4107 case GGML_OP_DIAG_MASK_INF:
4108 ggml_sycl_diag_mask_inf(ctx, dst);
4109 break;
4110 case GGML_OP_SOFT_MAX:
4111 ggml_sycl_op_soft_max(ctx, dst);
4112 break;
4113 case GGML_OP_SOFT_MAX_BACK:
4114 ggml_sycl_op_soft_max_back(ctx, dst);
4115 break;
4116 case GGML_OP_ROPE:
4117 ggml_sycl_rope(ctx, dst);
4118 break;
4119 case GGML_OP_IM2COL:
4120 ggml_sycl_im2col(ctx, dst);
4121 break;
4122 case GGML_OP_POOL_2D:
4123 ggml_sycl_pool2d(ctx, dst);
4124 break;
4125 case GGML_OP_SUM:
4126 ggml_sycl_sum(ctx, dst);
4127 break;
4128 case GGML_OP_SUM_ROWS:
4129 ggml_sycl_sum_rows(ctx, dst);
4130 break;
4131 case GGML_OP_MEAN:
4132 ggml_sycl_mean(ctx, dst);
4133 break;
4134 case GGML_OP_ARGSORT:
4135 ggml_sycl_argsort(ctx, dst);
4136 break;
4137 case GGML_OP_TOP_K:
4138 ggml_sycl_op_top_k(ctx, dst);
4139 break;
4140 case GGML_OP_TIMESTEP_EMBEDDING:
4141 ggml_sycl_op_timestep_embedding(ctx, dst);
4142 break;
4143 case GGML_OP_RWKV_WKV6:
4144 ggml_sycl_op_rwkv_wkv6(ctx, dst);
4145 break;
4146 case GGML_OP_RWKV_WKV7:
4147 ggml_sycl_op_rwkv_wkv7(ctx, dst);
4148 break;
4149 case GGML_OP_GATED_LINEAR_ATTN:
4150 ggml_sycl_op_gated_linear_attn(ctx, dst);
4151 break;
4152 case GGML_OP_SSM_CONV:
4153 ggml_sycl_ssm_conv(ctx, dst);
4154 break;
4155 case GGML_OP_ROLL:
4156 ggml_sycl_roll(ctx, dst);
4157 break;
4158 case GGML_OP_ARANGE:
4159 ggml_sycl_arange(ctx, dst);
4160 break;
4161 default:
4162 return false;
4163 }
4164
4165 return true;
4166} catch (sycl::exception & e) {
4167 std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::cerr << "Error OP " << ggml_op_name(dst->op) << std::endl;
4169 std::exit(1);
4170}
4171
4172GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
4173 size_t description_size) try {
4174 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
4175 dpct::device_info prop;
4176 SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
4177 prop, dpct::dev_mgr::instance().get_device(device))));
4178 snprintf(description, description_size, "%s", prop.get_name());
4179}
4180catch (sycl::exception const &exc) {
4181 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4182 << ", line:" << __LINE__ << std::endl;
4183 std::exit(1);
4184}
4185
4186void ggml_backend_sycl_get_device_memory(int device, size_t *free,
4187 size_t *total) try {
4188 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
4189 ggml_sycl_set_device(device);
4190
4191 SYCL_CHECK(CHECK_TRY_ERROR(
4192 dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
4193}
4194catch (sycl::exception const &exc) {
4195 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4196 << ", line:" << __LINE__ << std::endl;
4197 std::exit(1);
4198}
4199
4200////////////////////////////////////////////////////////////////////////////////
4201
4202// backend
4203
4204static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
4205
4206 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4207
4208 return sycl_ctx->name.c_str();
4209}
4210
4211static void ggml_backend_sycl_free(ggml_backend_t backend) {
4212 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4213
4214 delete sycl_ctx;
4215 delete backend;
4216}
4217
4218static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
4219 ggml_tensor *tensor,
4220 const void *data, size_t offset,
4221 size_t size) try {
4222 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4223 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
4224 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4225 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4226 ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4227
4228 GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
4229 const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4230 SYCL_CHECK(CHECK_TRY_ERROR(
4231 (stream)->memcpy((char *)tensor->data + offset, data, size)));
4232}
4233catch (sycl::exception const &exc) {
4234 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4235 << ", line:" << __LINE__ << std::endl;
4236 std::exit(1);
4237}
4238
4239static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
4240 const ggml_tensor *tensor,
4241 void *data, size_t offset,
4242 size_t size) try {
4243 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4244 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
4245 GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4246 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4247 ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4248
4249 GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
4250 const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4251 SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4252 data, (const char *)tensor->data + offset, size)));
4253}
4254catch (sycl::exception const &exc) {
4255 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4256 << ", line:" << __LINE__ << std::endl;
4257 std::exit(1);
4258}
4259
4260static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
4261 const ggml_tensor *src,
4262 ggml_tensor *dst) try {
4263 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4264 bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
4265 ggml_backend_buffer_is_sycl(src->buffer);
4266 GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
4267 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
4268 GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
4269 GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
4270 if (is_cpy_supported) {
4271 /*
4272 DPCT1009:215: SYCL uses exceptions to report errors and does not use the
4273 error codes. The original code was commented out and a warning string
4274 was inserted. You need to rewrite this code.
4275 */
4276 const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4277 SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4278 dst->data, src->data, ggml_nbytes(dst))));
4279 return true;
4280 }
4281
4282 return false;
4283}
4284catch (sycl::exception const &exc) {
4285 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4286 << ", line:" << __LINE__ << std::endl;
4287 std::exit(1);
4288}
4289
4290static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
4291 GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4292 ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4293 const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4294 SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
4295
4296 GGML_UNUSED(backend);
4297}
4298catch (sycl::exception const &exc) {
4299 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4300 << ", line:" << __LINE__ << std::endl;
4301 std::exit(1);
4302}
4303
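// Runs every node of the graph sequentially on the SYCL backend, skipping empty tensors,
// layout-only ops and nodes not flagged for compute.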
4304static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
4305 ggml_sycl_set_main_device(sycl_ctx->device);
4306
4307 for (int i = 0; i < cgraph->n_nodes; i++) {
4308 ggml_tensor * node = cgraph->nodes[i];
4309 if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
4310 continue;
4311 }
4312 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
4313 continue;
4314 }
4315#ifndef NDEBUG
4316 assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
4317 for (int j = 0; j < GGML_MAX_SRC; j++) {
4318 if (node->src[j] != nullptr) {
4319 assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
4320 }
4321 }
4322#endif
4323 bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
4324 if (!ok) {
4325 GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
4326 }
4327 GGML_ASSERT(ok);
4328 }
4329}
4330
4331#ifdef GGML_SYCL_GRAPH
4332static bool check_graph_compatibility(ggml_cgraph * cgraph) {
4333 if (ggml_sycl_info().device_count > 1) {
4334 // A sycl_ex::command_graph object can only be created for a single device
4335 GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
4336 return false;
4337 }
4338
4339 for (int i = 0; i < cgraph->n_nodes; i++) {
4340 const ggml_op node_op = cgraph->nodes[i]->op;
4341 switch (node_op) {
4342 default:
4343 break;
4344 case GGML_OP_CONCAT:
4345 // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
4346 // but wait() can't be called on the events returned by a queue recording
4347 // to a graph.
4348 [[fallthrough]];
4349 case GGML_OP_MUL_MAT_ID:
4350 // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
4351 // submitting a memcpy operation, but wait() can't be called on a queue that
4352 // is recording to a graph.
4353 GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
4354 ggml_op_name(node_op));
4355 return false;
4356 case GGML_OP_MUL_MAT:
4357 // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
4358 // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
4359 // in reordering.
4360 if (!g_ggml_sycl_use_async_mem_op) {
4361 GGML_LOG_INFO(
4362 "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
4363 "oneAPI async memory allocation extension "
4364 "%s\n",
4365 __func__, ggml_op_name(node_op));
4366 return false;
4367 }
4368 }
4369 }
4370 return true;
4371}
4372#endif
4373
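// Executes the compute graph either node by node or, when SYCL graphs are enabled and
// supported by the device, by recording the nodes into a sycl_ex::command_graph and
// replaying (or updating) the finalized executable graph.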
4374static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4375 auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
4376
4377#ifdef GGML_SYCL_GRAPH
4378 bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
4379 if (use_sycl_graph) {
4380 const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
4381 if (!graph_support) {
            GGML_SYCL_DEBUG("[SYCL-GRAPH] cannot use graphs on device: %d\n", sycl_ctx->device);
4383 ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4384 return GGML_STATUS_SUCCESS;
4385 }
4386
4387 sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
4388
4389 model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
4390 ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4391 model_sycl_graph.end_recording();
4392
4393 const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
4394 if (!sycl_ctx->exec_graph || !graph_update_support) {
4395 auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
4396 model_sycl_graph.finalize();
4397 sycl_ctx->exec_graph = std::make_unique<
4398 sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4399 } else {
4400 try {
4401 sycl_ctx->exec_graph->update(model_sycl_graph);
4402 GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
4403 } catch (sycl::exception const & e) {
4404 GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
4405 auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
4406 sycl_ctx->exec_graph = std::make_unique<
4407 sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4408 }
4409 }
4410
4411 sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
4412 } else
4413#endif
4414 {
4415 ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4416 }
4417 return GGML_STATUS_SUCCESS;
4418}
4419
4420static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
4421try
4422{
4423 ggml_backend_sycl_context *sycl_ctx =
4424 (ggml_backend_sycl_context *)backend->context;
4425
4426 sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4427
4428 const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
4429 // Record the current state of the queue
4430 SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
4431}
4432catch (sycl::exception const &exc)
4433{
4434 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4435 << ", line:" << __LINE__ << std::endl;
4436 std::exit(1);
4437}
4438
4439static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
4440 GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4441 sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
4442
4443 if (ggml_backend_is_sycl(backend)) {
4444 SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
4445 } else
4446 GGML_ABORT("fatal error");
4447} catch (sycl::exception const& exc) {
4448 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4449 << ", line:" << __LINE__ << std::endl;
4450 std::exit(1);
4451}
4452
4453static ggml_backend_i ggml_backend_sycl_interface = {
4454 /* .get_name = */ ggml_backend_sycl_get_name,
4455 /* .free = */ ggml_backend_sycl_free,
4456 /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async,
4457 /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async,
4458 /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async,
                                            // TODO: update for the new interface
4461 /* .synchronize = */ ggml_backend_sycl_synchronize,
4462 /* .graph_plan_create = */ NULL,
4463 /* .graph_plan_free = */ NULL,
4464 /* .graph_plan_update = */ NULL,
4465 /* .graph_plan_compute = */ NULL,
4466 /* .graph_compute = */ ggml_backend_sycl_graph_compute,
4467 /* .event_record = */ ggml_backend_sycl_event_record,
4468 /* .event_wait = */ ggml_backend_sycl_event_wait,
4469 /* .graph_optimize = */ NULL,
4470};
4471
4472static ggml_guid_t ggml_backend_sycl_guid() {
4473 static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
4474 return &guid;
4475}
4476
4477bool ggml_backend_is_sycl(ggml_backend_t backend) {
4478 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
4479}
4480
4481int ggml_backend_sycl_get_device_count() {
4482 return ggml_sycl_info().device_count;
4483}
4484
4485
4486// backend device
4487
4488struct ggml_backend_sycl_device_context {
4489 int device;
4490 std::string name;
4491 std::string description;
4492 int op_offload_min_batch_size;
4493};
4494
4495static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
4496 ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4497 return ctx->name.c_str();
4498}
4499
4500static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
4501 ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4502 return ctx->description.c_str();
4503}
4504
4505static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
4506 ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4507 ggml_sycl_set_device(ctx->device);
4508 SYCL_CHECK(CHECK_TRY_ERROR(
4509 dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
4510}
4511
4512static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
4513 GGML_UNUSED(dev);
4514 return GGML_BACKEND_DEVICE_TYPE_GPU;
4515}
4516
4517static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
4518 props->name = ggml_backend_sycl_device_get_name(dev);
4519 props->description = ggml_backend_sycl_device_get_description(dev);
4520 props->type = ggml_backend_sycl_device_get_type(dev);
4521 ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
4522
4523 bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
4524#ifdef GGML_SYCL_NO_PEER_COPY
4525 bool events = false;
4526#else
4527 bool events = true;
4528#endif
4529
4530 props->caps = {
4531 /* .async = */ true,
4532 /* .host_buffer = */ host_buffer,
4533 /* .buffer_from_host_ptr = */ false,
4534 /* .events = */ events,
4535 };
4536}
4537
4538static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
4539 GGML_UNUSED(params);
4540 ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4541 return ggml_backend_sycl_init(ctx->device);
4542}
4543
4544static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
4545 ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
4546 return ggml_backend_sycl_buffer_type(ctx->device);
4547}
4548
4549static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
4550 GGML_UNUSED(dev);
4551 return ggml_backend_sycl_host_buffer_type();
4552}
4553
4554static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
4555 GGML_UNUSED(dev);
4556 GGML_UNUSED(ptr);
4557 GGML_UNUSED(size);
4558 GGML_UNUSED(max_tensor_size);
4559 return nullptr;
4560}
4561
4562static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4563 ggml_backend_sycl_device_context *sycl_ctx =
4564 (ggml_backend_sycl_device_context *)dev->context;
4565 int device = sycl_ctx->device;
4566 switch (op->op) {
4567 case GGML_OP_CONV_TRANSPOSE_1D:
4568 {
4569 ggml_type src0_type = op->src[0]->type;
4570 ggml_type src1_type = op->src[1]->type;
4571 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4572 return true;
4573 }
4574 return false;
4575 }
4576 case GGML_OP_UNARY:
4577 switch (ggml_get_unary_op(op)) {
4578 case GGML_UNARY_OP_SGN:
4579 case GGML_UNARY_OP_ABS:
4580 case GGML_UNARY_OP_NEG:
4581 case GGML_UNARY_OP_STEP:
4582 case GGML_UNARY_OP_RELU:
4583 case GGML_UNARY_OP_HARDSIGMOID:
4584 case GGML_UNARY_OP_TANH:
4585 case GGML_UNARY_OP_GELU:
4586 case GGML_UNARY_OP_SILU:
4587 case GGML_UNARY_OP_SIGMOID:
4588 case GGML_UNARY_OP_HARDSWISH:
4589 case GGML_UNARY_OP_GELU_QUICK:
4590 case GGML_UNARY_OP_GELU_ERF:
4591 case GGML_UNARY_OP_EXP:
4592 case GGML_UNARY_OP_SOFTPLUS:
4593 case GGML_UNARY_OP_ELU:
4594 case GGML_UNARY_OP_CEIL:
4595 return true;
4596 case GGML_UNARY_OP_FLOOR:
4597 case GGML_UNARY_OP_ROUND:
4598 case GGML_UNARY_OP_TRUNC:
4599#if defined (GGML_SYCL_F16)
4600 return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4601#else
4602 return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4603#endif
4604 default:
4605 return false;
4606 }
4607 case GGML_OP_GLU:
4608 switch (ggml_get_glu_op(op)) {
4609 case GGML_GLU_OP_REGLU:
4610 case GGML_GLU_OP_GEGLU:
4611 case GGML_GLU_OP_SWIGLU:
4612 case GGML_GLU_OP_SWIGLU_OAI:
4613 case GGML_GLU_OP_GEGLU_ERF:
4614 case GGML_GLU_OP_GEGLU_QUICK:
4615 return ggml_is_contiguous_1(op->src[0]);
4616 default:
4617 return false;
4618 }
4619 break;
4620 case GGML_OP_MUL_MAT:
4621 case GGML_OP_MUL_MAT_ID:
4622 {
4623 struct ggml_tensor * a = op->src[0];
4624 struct ggml_tensor * b = op->src[1];
4625
4626 if (a->ne[3] != b->ne[3]) {
4627 return false;
4628 }
4629 ggml_type a_type = a->type;
4630 if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
4631 a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
4632 a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
4633 a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
4634 ) {
4635 if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
4636 return false;
4637 }
4638 }
4639 ggml_type src0_type = op->src[0]->type;
4640 if (src0_type == GGML_TYPE_BF16 ) {
4641 // TODO: support GGML_TYPE_BF16
4642 // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4643 return false;
4644 }
4645
4646 // TODO: The configuration below needs more work to be supported with oneDNN
4647 if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
4648 a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
4649 return false;
4650 }
4651
4652 // TODO: This specific configuration can fail with oneDNN and needs more debugging
4653 if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
4654 a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
4655 return false;
4656 }
4657 return true;
4658 }
4659 case GGML_OP_OUT_PROD:
4660 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
4661 case GGML_OP_GET_ROWS:
4662 {
4663 switch (op->src[0]->type) {
4664 case GGML_TYPE_F16:
4665 case GGML_TYPE_F32:
4666 case GGML_TYPE_Q4_0:
4667 case GGML_TYPE_Q4_1:
4668 case GGML_TYPE_Q5_0:
4669 case GGML_TYPE_Q5_1:
4670 case GGML_TYPE_Q8_0:
4671 return true;
4672 default:
4673 return false;
4674 }
4675 }
4676 case GGML_OP_SET:
4677 return (op->type == GGML_TYPE_F32) &&
4678 (op->src[0] && op->src[1]) &&
4679 (op->src[0]->type == GGML_TYPE_F32) &&
4680 (op->src[1]->type == GGML_TYPE_F32);
4681
4682 case GGML_OP_SET_ROWS:
4683 {
4684 return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
4685 op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
4686 op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
4687 (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
4688 }
4689 break;
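        // copy/convert: same-type contiguous copies (except BF16) plus the explicit
        // type pairs listed below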
4690 case GGML_OP_CPY:
4691 {
4692 ggml_type src0_type = op->src[0]->type;
4693 ggml_type src1_type = op->src[1]->type;
4694 if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
4695 return true;
4696 }
4697 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4698 return true;
4699 }
4700 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
4701 return true;
4702 }
4703 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
4704 return true;
4705 }
4706 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
4707 return true;
4708 }
4709 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
4710 return true;
4711 }
4712 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
4713 return true;
4714 }
4715 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
4716 return true;
4717 }
4718 if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
4719 return true;
4720 }
4721 if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
4722 return true;
4723 }
4724 if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
4725 return true;
4726 }
4727 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
4728 return true;
4729 }
4730 if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
4731 return true;
4732 }
4733 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
4734 return true;
4735 }
4736 if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
4737 return true;
4738 }
4739 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
4740 return true;
4741 }
                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
                    return true;
                }
4757 return false;
4758 }
4759 case GGML_OP_REPEAT_BACK:
4760 {
4761 ggml_type src0_type = op->src[0]->type;
4762 return src0_type == GGML_TYPE_F32;
4763 }
4764 case GGML_OP_CONCAT:
4765 case GGML_OP_DUP:
4766 case GGML_OP_ARGMAX:
4767 case GGML_OP_NONE:
4768 case GGML_OP_RESHAPE:
4769 case GGML_OP_VIEW:
4770 case GGML_OP_PERMUTE:
4771 case GGML_OP_TRANSPOSE:
4772 case GGML_OP_ADD:
4773 case GGML_OP_ADD1:
4774 case GGML_OP_ADD_ID:
4775 case GGML_OP_SUB:
4776 case GGML_OP_COUNT_EQUAL:
4777 case GGML_OP_MUL:
4778 case GGML_OP_DIV:
4779 case GGML_OP_REPEAT:
4780 return true;
4781 case GGML_OP_PAD_REFLECT_1D:
            return ggml_is_contiguous(op->src[0]) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4783 case GGML_OP_SQR:
4784 case GGML_OP_SQRT:
4785 case GGML_OP_SIN:
4786 case GGML_OP_COS:
4787 case GGML_OP_CLAMP:
4788 case GGML_OP_LOG:
4789#if defined (GGML_SYCL_F16)
            return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->type == op->src[0]->type));
4791#else
            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4793#endif
4794 case GGML_OP_NORM:
4795 case GGML_OP_L2_NORM:
4796 case GGML_OP_GROUP_NORM:
4797 case GGML_OP_RMS_NORM:
4798 return true;
4799 case GGML_OP_RMS_NORM_BACK:
4800 return ggml_is_contiguous(op->src[0]);
4801 case GGML_OP_SCALE:
4802 return true;
4803 case GGML_OP_CONT:
4804 return op->src[0]->type != GGML_TYPE_BF16;
4805 case GGML_OP_TRI:
4806 {
4807 const ggml_tensor * src0 = op->src[0];
4808 return src0 &&
4809 op->type == GGML_TYPE_F32 &&
4810 ggml_is_contiguous(src0);
4811 }
4812 case GGML_OP_DIAG_MASK_INF:
4813 return true;
4814 case GGML_OP_SOFT_MAX:
4815 return true;
4816 case GGML_OP_SOFT_MAX_BACK: {
4817 float max_bias = 0.0f;
4818 memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4819 return max_bias == 0.0f;
4820 }
4821 case GGML_OP_ROPE:
4822 case GGML_OP_IM2COL:
4823 return true;
4824 case GGML_OP_UPSCALE:
4825 return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4826 case GGML_OP_SUM:
4827 case GGML_OP_SUM_ROWS:
4828 case GGML_OP_MEAN:
4829 return ggml_is_contiguous(op->src[0]);
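        // the argsort kernel needs one int index per row element to fit into the
        // device's shared local memory (smpbo), which bounds the supported row length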
4830 case GGML_OP_ARGSORT:
4831 return op->src[0]->ne[0] * sizeof(int) <=
4832 ggml_sycl_info().devices[device].smpbo;
4833 case GGML_OP_TOP_K: {
4834 const ggml_tensor * src0 = op->src[0];
4835 const int k = op->ne[0];
4836 return src0 &&
4837 op->type == GGML_TYPE_I32 &&
4838 src0->type == GGML_TYPE_F32 &&
4839 ggml_is_contiguous(src0) &&
4840 k > 0 && k <= 32;
4841 }
4842 case GGML_OP_POOL_2D:
4843 case GGML_OP_ACC:
4844 return true;
4845 case GGML_OP_PAD:
            // TODO: add circular padding support for SYCL, see https://github.com/ggml-org/llama.cpp/pull/16985
4847 if (ggml_get_op_params_i32(op, 8) != 0) {
4848 return false;
4849 }
4850 return ggml_is_contiguous(op->src[0]);
4851 case GGML_OP_LEAKY_RELU:
4852 case GGML_OP_TIMESTEP_EMBEDDING:
4853 case GGML_OP_RWKV_WKV6:
4854 case GGML_OP_RWKV_WKV7:
4855 case GGML_OP_GATED_LINEAR_ATTN:
4856 return true;
4857 case GGML_OP_SSM_CONV:
4858 return op->type == GGML_TYPE_F32 &&
4859 op->src[0]->type == GGML_TYPE_F32 &&
4860 op->src[1]->type == GGML_TYPE_F32;
4861 case GGML_OP_ROLL:
4862 return op->type == GGML_TYPE_F32;
4863 case GGML_OP_ARANGE:
4864 return op->type == GGML_TYPE_F32;
4865 default:
4866 return false;
4867 }
4868
4869 GGML_UNUSED(dev);
4870}
4871
4872static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
4873 if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
4874 return false;
4875 }
4876 ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
4877 ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4878 return buft_ctx->device == sycl_ctx->device;
4879}
4880
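// Batch size of an op as used by the offload heuristic below: 0 for GET_ROWS
// (so it is never offloaded on its own), ne[1] for MUL_MAT, ne[2] for
// MUL_MAT_ID and ROPE, and the total number of rows otherwise.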
4881static int64_t get_op_batch_size(const ggml_tensor * op) {
4882 switch (op->op) {
4883 case GGML_OP_GET_ROWS:
4884 return 0;
4885 case GGML_OP_MUL_MAT:
4886 return op->ne[1];
4887 case GGML_OP_MUL_MAT_ID:
4888 case GGML_OP_ROPE:
4889 return op->ne[2];
4890 default:
4891 return ggml_nrows(op);
4892 }
4893}
4894
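// Offload an op to this device only when its batch size reaches the configured minimum.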
4895static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4896 ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4897 return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
4898}
4899
4900static ggml_backend_event_t
4901ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
4902
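    // with GGML_SYCL_NO_PEER_COPY events are not supported (return nullptr);
    // otherwise wrap a plain sycl::event, released again in ggml_backend_sycl_device_event_free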
4903#ifdef GGML_SYCL_NO_PEER_COPY
4904 return nullptr;
4905#else
4906 sycl::event *event_ptr = new sycl::event();
4907
4908 return new ggml_backend_event{
4909 /* .device = */ dev,
4910 /* .context = */ event_ptr,
4911 };
4912#endif
4913}
4914
4915static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
4916 GGML_UNUSED(dev);
4917 if (event == nullptr) {
4918 return;
4919 }
4920
4921 if (event->context != nullptr) {
4922 sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4923 delete sycl_event;
4924 event->context = nullptr;
4925 }
4926
4927 delete event;
4928} catch (sycl::exception const &exc) {
4929 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4930 << ", line:" << __LINE__ << std::endl;
4931 std::exit(1);
4932}
4933
4934
4935static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
4936 GGML_UNUSED(dev);
4937 GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4938
4939 sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4940 SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
4941} catch (sycl::exception const &exc) {
4942 std::cerr << exc.what() << "Exception caught at file:" << __FILE__
4943 << ", line:" << __LINE__ << std::endl;
4944 std::exit(1);
4945}
4946
4947static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
4948 /* .get_name = */ ggml_backend_sycl_device_get_name,
4949 /* .get_description = */ ggml_backend_sycl_device_get_description,
4950 /* .get_memory = */ ggml_backend_sycl_device_get_memory,
4951 /* .get_type = */ ggml_backend_sycl_device_get_type,
4952 /* .get_props = */ ggml_backend_sycl_device_get_props,
4953 /* .init_backend = */ ggml_backend_sycl_device_init,
4954 /* .get_buffer_type = */ ggml_backend_sycl_device_get_buffer_type,
4955 /* .get_host_buffer_type = */ ggml_backend_sycl_device_get_host_buffer_type,
4956 /* .buffer_from_host_ptr = */ ggml_backend_sycl_device_buffer_from_host_ptr,
4957 /* .supports_op = */ ggml_backend_sycl_device_supports_op,
4958 /* .supports_buft = */ ggml_backend_sycl_device_supports_buft,
4959 /* .offload_op = */ ggml_backend_sycl_device_offload_op,
4960 /* .event_new = */ ggml_backend_sycl_device_event_new,
4961 /* .event_free = */ ggml_backend_sycl_device_event_free,
4962 /* .event_synchronize = */ ggml_backend_sycl_device_event_synchronize,
4963};
4964
4965// backend reg
4966
4967struct ggml_backend_sycl_reg_context {
4968 std::vector<ggml_backend_dev_t> devices;
4969};
4970
4971static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
4972 GGML_UNUSED(reg);
4973 return GGML_SYCL_NAME;
4974}
4975
4976static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
4977 ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
4978 return ctx->devices.size();
4979}
4980
4981static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
4982 ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
4983 GGML_ASSERT(index < ctx->devices.size());
4984 return ctx->devices[index];
4985}
4986
4987static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
4988 GGML_UNUSED(reg);
4989
4990 if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
4991 return (void *)ggml_backend_sycl_split_buffer_type;
4992 }
4993
4994 // SYCL doesn't support registering host memory, left here for reference
4995 // "ggml_backend_register_host_buffer"
4996 // "ggml_backend_unregister_host_buffer"
4997 GGML_UNUSED(name);
4998 return nullptr;
4999}
5000
5001static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
5002 /* .get_name = */ ggml_backend_sycl_reg_get_name,
5003 /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count,
5004 /* .get_device = */ ggml_backend_sycl_reg_get_device,
5005 /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address,
5006};
5007
5008
5009// backend registry
5010
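// Lazily builds the SYCL backend registry on first call: one ggml_backend_device
// is created per detected SYCL device. Construction is guarded by a mutex so that
// concurrent first calls are safe.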
5011ggml_backend_reg_t ggml_backend_sycl_reg() {
5012 static ggml_backend_reg reg;
5013 static bool initialized = false;
5014
5015 {
5016 static std::mutex mutex;
5017 std::lock_guard<std::mutex> lock(mutex);
5018 if (!initialized) {
5019 ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
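            // minimum batch size before ops are offloaded to this backend; can be
            // overridden with the GGML_OP_OFFLOAD_MIN_BATCH environment variable (default: 32)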
5020 const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
5021
5022 for (int i = 0; i < ggml_sycl_info().device_count; i++) {
5023 ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
5024 dev_ctx->device = i;
5025 dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
5026
5027 ggml_sycl_set_device(i);
5028
5029 dpct::device_info prop;
5030 SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
5031 prop, dpct::dev_mgr::instance().get_device(i))));
5032
5033 dev_ctx->description = prop.get_name();
5034 dev_ctx->op_offload_min_batch_size = min_batch_size;
5035
5036 ggml_backend_dev_t dev = new ggml_backend_device {
5037 /* .iface = */ ggml_backend_sycl_device_interface,
                    /* .reg = */ &reg,
5039 /* .context = */ dev_ctx
5040 };
5041 ctx->devices.push_back(dev);
5042 }
5043
5044 reg = ggml_backend_reg {
5045 /* .api_version = */ GGML_BACKEND_API_VERSION,
5046 /* .iface = */ ggml_backend_sycl_reg_interface,
5047 /* .context = */ ctx
5048 };
5049 }
5050
5051 initialized = true;
5052 }
5053
    return &reg;
5055}
5056
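// Creates a backend instance bound to the given SYCL device index.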
5057ggml_backend_t ggml_backend_sycl_init(int device) {
5058 GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
5059 ggml_check_sycl();
5060
5061 check_allow_gpu_index(device);
5062
5063 ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
5064 if (ctx == nullptr) {
5065 GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
5066 return nullptr;
5067 };
5068
5069 ggml_backend_t sycl_backend = new ggml_backend {
5070 /* .guid = */ ggml_backend_sycl_guid(),
5071 /* .iface = */ ggml_backend_sycl_interface,
5072 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
5073 /* .context = */ ctx
5074 };
5075
5076 return sycl_backend;
5077}
5078
5079GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg)