1#include "outprod.hpp"
 2
 3void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 4    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
 5    const ggml_tensor *src0 = dst->src[0];
 6    const ggml_tensor *src1 = dst->src[1];
 7
 8    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 9    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10    GGML_ASSERT(dst->type == GGML_TYPE_F32);
11    GGML_ASSERT(ggml_is_contiguous(src0));
12    GGML_ASSERT(ggml_is_contiguous(dst));
13
14    GGML_TENSOR_BINARY_OP_LOCALS
15
16    // Get SYCL queue
17    dpct::queue_ptr stream = ctx.stream();
18
19    // Dimension checks
20    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match
21    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows
22    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols
23
24    // Get data pointers
25    const float* src0_d = (const float*)src0->data;
26    const float* src1_d = (const float*)src1->data;
27    float* dst_d = (float*)dst->data;
28
29    // GEMM parameters
30    const float alpha = 1.0f;
31    const float beta = 0.0f;
32
33    // Handle transposition of src1
34    const bool src1_T = ggml_is_transposed(src1);
35    const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
36    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
37
38    try {
39        // Perform matrix multiplication using oneMKL GEMM
40        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,
41                                               ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
42    }
43    catch (sycl::exception const& exc) {
44        std::cerr << exc.what() << std::endl;
45        GGML_ASSERT(false);
46    }
47}