summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-sycl/gemm.hpp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/gemm.hpp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/gemm.hpp')
-rw-r--r--llama.cpp/ggml/src/ggml-sycl/gemm.hpp90
1 files changed, 90 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-sycl/gemm.hpp b/llama.cpp/ggml/src/ggml-sycl/gemm.hpp
new file mode 100644
index 0000000..dcf6c7a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-sycl/gemm.hpp
@@ -0,0 +1,90 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+ using dt = dnnl::memory::data_type;
+ using tag = dnnl::memory::format_tag;
+
+ template<typename T>
+ static constexpr dt to_dt() {
+ if constexpr (std::is_same_v<T, float>) return dt::f32;
+ else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+ else static_assert(0);
+ }
+
+ static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+ const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+ const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
+ void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
+ auto stream = ctx.stream_dnnl(q);
+ auto eng = ctx.engine_dnnl(q);
+
+ dnnl::memory::dims a_dims = {batches_a, m, k };
+ dnnl::memory::dims a_strides = {stra2, stra1, stra0};
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+ dnnl::memory::dims b_dims = {batches_b, k, n };
+ dnnl::memory::dims b_strides = {strb2, strb0, strb1};
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+
+ dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
+ dnnl::memory::dims c_strides = {m*n, 1, m };
+ const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
+ dnnl::primitive_attr primitive_attr;
+ primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+#ifdef GGML_SYCL_F16
+ primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
+#endif
+
+ auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
+ auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+ auto scratchpad_md = matmul_pd.scratchpad_desc();
+ auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
+ auto matmul_prim = dnnl::matmul(matmul_pd);
+
+ std::unordered_map<int, dnnl::memory> matmul_args;
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
+ matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
+
+ matmul_prim.execute(stream, matmul_args);
+ }
+
+ static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+ gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+ }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP