1#ifndef GGML_ZDNN_COMMON_HPP
 2#define GGML_ZDNN_COMMON_HPP
 3
 4#include "ggml.h"
 5#include "ggml-impl.h"
 6
 7#include "zdnn.h"
 8
 9#include <vector>
10#include <memory>
11
// Name string for the zDNN backend.
#define GGML_ZDNN_NAME    "zDNN"
// Version value for the zDNN backend; expands to ZDNN_VERNUM, which this
// header does not define itself (NOTE(review): presumably supplied by zdnn.h — confirm).
#define GGML_ZDNN_VERSION ZDNN_VERNUM
14
// Evaluates `stmt` exactly once, stores the result in a local zdnn_status and
// asserts through GGML_ASSERT that it equals ZDNN_OK.
//
// The expansion intentionally does NOT end in a semicolon: the `;` written by
// the caller completes the do/while statement, so the macro behaves as one
// ordinary statement and can be used in an unbraced
//     if (cond) ZDNN_CHECK(call()); else ...
// Always invoke it as `ZDNN_CHECK(expr);`.
//
// Caveat: the temporary is named `status`, so a `stmt` that refers to a caller
// variable called `status` would bind to the macro's local instead.
#define ZDNN_CHECK(stmt)                \
    do {                                \
        zdnn_status status = (stmt);    \
        GGML_ASSERT(status == ZDNN_OK); \
    } while (0)
20
// Device-level context of the zDNN backend.
//
// Plain aggregate: no member has a default initializer, so the code that
// creates an instance is responsible for initializing every field.
struct ggml_backend_zdnn_device_context {
    int zdnn_device;            // device identifier (NOTE(review): presumably an index — confirm at use sites)
    int zdnn_device_ref_count;  // reference counter for the device; it is updated by code outside this header

    bool has_parmblkformat_0;   // capability flag for parameter-block format 0
    bool has_parmblkformat_1;   // capability flag for parameter-block format 1; checks for z17

    size_t max_size;            // NOTE(review): presumably the largest supported size in bytes — confirm

    char name[128];             // fixed 128-char buffer holding the device name
};
32
33struct ggml_backend_zdnn_context {
34    int device;
35    ggml_cgraph * gf;
36};
37
38struct ggml_backend_zdnn_buffer {
39    void * data;
40    ggml_backend_zdnn_buffer * extra;  // for bias, etc.
41    size_t size;
42
43    zdnn_tensor_desc pre_tfm_desc;
44    zdnn_tensor_desc tfm_desc;
45    zdnn_ztensor     ztensor;
46
47    char name[GGML_MAX_NAME];
48};
49
50struct ggml_backend_zdnn_buffer_context {
51    void * all_data;
52    size_t all_size;
53    bool owned;
54
55    int n_buffers;
56    std::vector<std::unique_ptr<ggml_backend_zdnn_buffer>> buffers;
57};
58
59#endif  // GGML_ZDNN_COMMON_HPP