| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/gemm.hpp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/gemm.hpp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-sycl/gemm.hpp | 90 |
|---|---|---|
1 file changed, 90 insertions, 0 deletions
```diff
diff --git a/llama.cpp/ggml/src/ggml-sycl/gemm.hpp b/llama.cpp/ggml/src/ggml-sycl/gemm.hpp
new file mode 100644
index 0000000..dcf6c7a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-sycl/gemm.hpp
@@ -0,0 +1,90 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template<typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+            const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
+            void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
+        auto stream = ctx.stream_dnnl(q);
+        auto eng = ctx.engine_dnnl(q);
+
+        dnnl::memory::dims a_dims    = { batches_a, m, k };
+        dnnl::memory::dims a_strides = { stra2, stra1, stra0 };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+        dnnl::memory::dims b_dims    = { batches_b, k, n };
+        dnnl::memory::dims b_strides = { strb2, strb0, strb1 };
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+
+        dnnl::memory::dims c_dims    = { std::max(batches_a, batches_b), m, n };
+        dnnl::memory::dims c_strides = { m*n, 1, m };
+        const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
+        dnnl::primitive_attr primitive_attr;
+        primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+#ifdef GGML_SYCL_F16
+        primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
+#endif
+
+        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
+        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        auto scratchpad_md  = matmul_pd.scratchpad_desc();
+        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP
```
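For orientation, below is a minimal sketch of how the new wrapper might be invoked. The function name, buffers, and sizes are hypothetical and not part of this commit; only `DnnlGemmWrapper` itself comes from the diff above. As the strides passed by `row_gemm` imply, A is read as a row-major m×k matrix, B as a k×n matrix stored column-major (i.e., an n×k row-major buffer read transposed), and C is written column-major m×n.

```cpp
// Hypothetical usage sketch (not part of this commit). Assumes a valid
// ggml_backend_sycl_context and queue_ptr, and a build with GGML_SYCL_DNNL.
#if GGML_SYCL_DNNL
static void example_mul_mat(ggml_backend_sycl_context & ctx, const queue_ptr & q,
                            const float * A,  // row-major m x k
                            const float * B,  // column-major k x n (n x k row-major buffer)
                            float       * C,  // written column-major m x n
                            int m, int n, int k) {
    // to_dt<T>() maps a C++ element type to the oneDNN data type at compile time.
    constexpr auto f32 = DnnlGemmWrapper::to_dt<float>();

    // Single-batch GEMM with the fixed strides row_gemm supplies.
    DnnlGemmWrapper::row_gemm(ctx, m, n, k, A, f32, B, f32, C, f32, q);
}
#endif
```

One design note visible in the diff: the wrapper opts into user-managed scratchpads (`dnnl::scratchpad_mode::user`) and fetches the buffer through `ctx.get_scratchpad_mem`, so the backend context can reuse a single scratchpad allocation across calls instead of letting oneDNN allocate one per primitive.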
