1#include "ggml.h"
 2#include "utils.hpp"
 3
 4zdnn_data_types ggml_zdnn_type_mapping(ggml_type type) {
 5    switch (type) {
 6        case GGML_TYPE_F32:
 7            return FP32;
 8        case GGML_TYPE_F16:
 9            return FP16;
10        case GGML_TYPE_BF16:
11            return BFLOAT;
12        case GGML_TYPE_Q8_0:
13            return INT8;
14        case GGML_TYPE_I8:
15            return INT8;
16        case GGML_TYPE_I32:
17            return INT32;
18        default:
19            GGML_ABORT("%s: fatal: unable to determine zTensor data type",
20                       __func__);
21            break;
22    }
23}
24
25void ggml_zdnn_create_tensor(zdnn_tensor_desc  & pre_tfm_desc,
26                             zdnn_tensor_desc  & tfm_desc,
27                             zdnn_ztensor      & ztensor,
28                       const ggml_tensor       * src,
29                       const int64_t           * ne,
30                       const zdnn_data_layouts   layout) {
31    zdnn_init_pre_transformed_desc(
32        layout,
33        ggml_zdnn_type_mapping(src->type),
34        &pre_tfm_desc,
35        ne[3], ne[2], ne[1], ne[0]
36    );
37
38    ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc, &tfm_desc));
39    ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor));
40}
41
42void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, void * buffer) {
43    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));
44}
45
46void ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor) {
47    switch (tensor->op) {
48        case GGML_OP_MUL_MAT:
49            {
50                zdnn_init_pre_transformed_desc(
51                    ZDNN_2D,
52                    ggml_zdnn_type_mapping(tensor->type),
53                    &buffer->pre_tfm_desc,
54                    tensor->ne[1], tensor->ne[0]
55                );
56            } break;
57
58        default:
59            {
60                // For 4D tensors, GGML uses NCHW layout. However, because zDNN
61                // automatically transforms everything to NHWC, we will use it
62                // directly to avoid the performance penalty changing the
63                // layout and reshaping the tensor.
64                zdnn_init_pre_transformed_desc(
65                    ZDNN_NHWC,
66                    ggml_zdnn_type_mapping(tensor->type),
67                    &buffer->pre_tfm_desc,
68                    tensor->ne[3], tensor->ne[2], tensor->ne[1], tensor->ne[0]
69                );
70
71                // TODO: Consider adding a ggml check.
72                // TODO: If tensor = 4D, use ZDNN_NCHW by default.
73                // TODO: If tensor = 2D, use ZDNN_NHWC by default.
74            } break;
75    }
76
77    ZDNN_CHECK(zdnn_generate_transformed_desc(&buffer->pre_tfm_desc, &buffer->tfm_desc));
78    ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor));
79}