1#include "ggml.h"
 2#include "mmf.hpp"
 3
 4void ggml_zdnn_mul_mat_f(
 5    const ggml_backend_zdnn_context * ctx,
 6    const               ggml_tensor * src0,
 7    const               ggml_tensor * src1,
 8                        ggml_tensor * dst) {
 9    GGML_TENSOR_BINARY_OP_LOCALS;
10
11    const enum ggml_type type = src0->type;
12
13    GGML_ASSERT(ne0 == ne01);
14    GGML_ASSERT(ne1 == ne11);
15    GGML_ASSERT(ne2 == ne12);
16    GGML_ASSERT(ne3 == ne13);
17
18    // we don't support permuted src0 or src1
19    GGML_ASSERT(nb00 == ggml_type_size(type));
20    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
21
22    // dst cannot be transposed or permuted
23    GGML_ASSERT(nb0 == sizeof(float));
24    GGML_ASSERT(nb0 <= nb1);
25    GGML_ASSERT(nb1 <= nb2);
26    GGML_ASSERT(nb2 <= nb3);
27
28    const ggml_tensor * weights = src0;
29    const ggml_tensor * inputs  = src1;
30          ggml_tensor * output  = dst;
31
32    ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra;
33    ggml_backend_zdnn_buffer * inputs_extra  = (ggml_backend_zdnn_buffer *)inputs->extra;
34    ggml_backend_zdnn_buffer * output_extra  = (ggml_backend_zdnn_buffer *)output->extra;
35    ggml_backend_zdnn_buffer * bias_extra    = (ggml_backend_zdnn_buffer *)output_extra->extra;
36
37    const int64_t weights_rows = ne01;
38    const int64_t weights_cols = ne00;
39    const int64_t inputs_rows  = ne11;
40    const int64_t inputs_cols  = ne10;
41
42    assert(inputs_cols == weights_cols);
43
44    const int64_t output_rows = ne1;
45    const int64_t output_cols = ne0;
46
47    // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n",
48    //               __func__, weights_extra->name,
49    //               weights->ne[3], weights->ne[2], weights->ne[1], weights->ne[0],
50    //               weights_extra->pre_tfm_desc.dim1,
51    //               weights_extra->pre_tfm_desc.dim2,
52    //               weights_extra->pre_tfm_desc.dim3,
53    //               weights_extra->pre_tfm_desc.dim4);
54
55    // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n",
56    //               __func__, inputs_extra->name,
57    //               inputs->ne[3], inputs->ne[2], inputs->ne[1], inputs->ne[0],
58    //               inputs_extra->pre_tfm_desc.dim1,
59    //               inputs_extra->pre_tfm_desc.dim2,
60    //               inputs_extra->pre_tfm_desc.dim3,
61    //               inputs_extra->pre_tfm_desc.dim4);
62
63    GGML_ASSERT(weights_extra->pre_tfm_desc.dim1 == weights->ne[0] && "weights_extra->pre_tfm_desc.dim1 must match weights->ne[0]");
64    GGML_ASSERT(weights_extra->pre_tfm_desc.dim2 == weights->ne[1] && "weights_extra->pre_tfm_desc.dim2 must match weights->ne[1]");
65    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1  == inputs->ne[0]  && "inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]");
66    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2  == inputs->ne[1]  && "inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]");
67
68    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor,
69                                        false, true, MATMUL_OP_ADDITION, &output_extra->ztensor));
70    // TODO: Remove in the future as we are currently DLF16 -> FP32 then in the next op, FP32 -> DLF16 again. Inefficient.
71    ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data));
72
73    GGML_UNUSED(ctx);
74    GGML_UNUSED(weights_rows);
75    GGML_UNUSED(weights_cols);
76    GGML_UNUSED(inputs_rows);
77    GGML_UNUSED(inputs_cols);
78    GGML_UNUSED(output_rows);
79    GGML_UNUSED(output_cols);
80}