1#ifndef GGML_ZDNN_COMMON_HPP
2#define GGML_ZDNN_COMMON_HPP
3
4#include "ggml.h"
5#include "ggml-impl.h"
6
7#include "zdnn.h"
8
9#include <vector>
10#include <memory>
11
// Backend name string for the zDNN backend.
#define GGML_ZDNN_NAME "zDNN"
// Backend version, forwarded from ZDNN_VERNUM (supplied by the zDNN headers).
#define GGML_ZDNN_VERSION ZDNN_VERNUM

// Evaluates `stmt` (an expression yielding a zdnn_status) exactly once and
// aborts via GGML_ASSERT unless the result is ZDNN_OK.
//
// - The do { } while (0) wrapper deliberately has NO trailing semicolon: the
//   caller supplies it, so `ZDNN_CHECK(f());` expands to exactly one statement
//   and is safe as the body of an un-braced if/else.
// - The temporary uses a macro-specific name so that a caller expression that
//   refers to its own variable called `status` is not shadowed by (and thus
//   initialized from) the macro's local.
#define ZDNN_CHECK(stmt)                                \
    do {                                                \
        const zdnn_status ggml_zdnn_status_ = (stmt);   \
        GGML_ASSERT(ggml_zdnn_status_ == ZDNN_OK);      \
    } while (0)
20
// Per-device state for the zDNN backend.
//
// Every member carries a default initializer, so an instance created by
// default-initialization (`new T`, or a local declared without braces) starts
// in a well-defined, zeroed state instead of holding indeterminate values.
// The struct remains an aggregate (C++14+), so positional brace
// initialization keeps working in member-declaration order.
struct ggml_backend_zdnn_device_context {
    int zdnn_device           = 0; // NOTE(review): presumably the device index — confirm
    int zdnn_device_ref_count = 0; // NOTE(review): presumably number of active users of the device — confirm

    bool has_parmblkformat_0 = false; // NOTE(review): presumably parameter-block format 0 support — confirm
    bool has_parmblkformat_1 = false; // checks for z17

    size_t max_size = 0; // NOTE(review): presumably a maximum buffer size in bytes — confirm

    char name[128] = {}; // device name (presumably a NUL-terminated C string)
};
32
33struct ggml_backend_zdnn_context {
34 int device;
35 ggml_cgraph * gf;
36};
37
38struct ggml_backend_zdnn_buffer {
39 void * data;
40 ggml_backend_zdnn_buffer * extra; // for bias, etc.
41 size_t size;
42
43 zdnn_tensor_desc pre_tfm_desc;
44 zdnn_tensor_desc tfm_desc;
45 zdnn_ztensor ztensor;
46
47 char name[GGML_MAX_NAME];
48};
49
50struct ggml_backend_zdnn_buffer_context {
51 void * all_data;
52 size_t all_size;
53 bool owned;
54
55 int n_buffers;
56 std::vector<std::unique_ptr<ggml_backend_zdnn_buffer>> buffers;
57};
58
59#endif // GGML_ZDNN_COMMON_HPP