1#include "ggml.h"
2#include "utils.hpp"
3
4zdnn_data_types ggml_zdnn_type_mapping(ggml_type type) {
5 switch (type) {
6 case GGML_TYPE_F32:
7 return FP32;
8 case GGML_TYPE_F16:
9 return FP16;
10 case GGML_TYPE_BF16:
11 return BFLOAT;
12 case GGML_TYPE_Q8_0:
13 return INT8;
14 case GGML_TYPE_I8:
15 return INT8;
16 case GGML_TYPE_I32:
17 return INT32;
18 default:
19 GGML_ABORT("%s: fatal: unable to determine zTensor data type",
20 __func__);
21 break;
22 }
23}
24
25void ggml_zdnn_create_tensor(zdnn_tensor_desc & pre_tfm_desc,
26 zdnn_tensor_desc & tfm_desc,
27 zdnn_ztensor & ztensor,
28 const ggml_tensor * src,
29 const int64_t * ne,
30 const zdnn_data_layouts layout) {
31 zdnn_init_pre_transformed_desc(
32 layout,
33 ggml_zdnn_type_mapping(src->type),
34 &pre_tfm_desc,
35 ne[3], ne[2], ne[1], ne[0]
36 );
37
38 ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc, &tfm_desc));
39 ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor));
40}
41
42void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, void * buffer) {
43 ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));
44}
45
46void ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor) {
47 switch (tensor->op) {
48 case GGML_OP_MUL_MAT:
49 {
50 zdnn_init_pre_transformed_desc(
51 ZDNN_2D,
52 ggml_zdnn_type_mapping(tensor->type),
53 &buffer->pre_tfm_desc,
54 tensor->ne[1], tensor->ne[0]
55 );
56 } break;
57
58 default:
59 {
60 // For 4D tensors, GGML uses NCHW layout. However, because zDNN
61 // automatically transforms everything to NHWC, we will use it
62 // directly to avoid the performance penalty changing the
63 // layout and reshaping the tensor.
64 zdnn_init_pre_transformed_desc(
65 ZDNN_NHWC,
66 ggml_zdnn_type_mapping(tensor->type),
67 &buffer->pre_tfm_desc,
68 tensor->ne[3], tensor->ne[2], tensor->ne[1], tensor->ne[0]
69 );
70
71 // TODO: Consider adding a ggml check.
72 // TODO: If tensor = 4D, use ZDNN_NCHW by default.
73 // TODO: If tensor = 2D, use ZDNN_NHWC by default.
74 } break;
75 }
76
77 ZDNN_CHECK(zdnn_generate_transformed_desc(&buffer->pre_tfm_desc, &buffer->tfm_desc));
78 ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor));
79}