#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "shared/apir_cs_rpc.h"

#include <cinttypes>
#include <cstdio>
#include <cstring>
#include <unordered_map>
#include <unordered_set>
#include <vector>
9
// Registry of buffers created by this backend. Used to validate buffer handles
// received over RPC before they are dereferenced (see apir_deserialize_tensor).
// NOTE(review): plain mutable global with no synchronization — assumes all RPC
// handling happens on a single thread; confirm before adding concurrency.
std::unordered_set<ggml_backend_buffer_t> backend_buffers;
11
12void apir_track_backend_buffer(ggml_backend_buffer_t buffer) {
13 backend_buffers.insert(buffer);
14}
15
16bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {
17 auto it = backend_buffers.find(buffer);
18 if (it == backend_buffers.end()) {
19 return false;
20 }
21
22 backend_buffers.erase(it);
23 return true;
24}
25
26std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {
27 return backend_buffers;
28}
29
30ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {
31 ggml_tensor * result =
32 ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
33 for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
34 result->nb[i] = tensor->nb[i];
35 }
36 result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
37 if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {
38 printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer);
39 result->buffer = nullptr;
40 }
41
42 uint64_t tensor_data = tensor->data;
43 if (result->buffer) {
44 // require that the tensor data does not go beyond the buffer end
45 uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
46 uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
47 uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
48
49 // tensor->data is serialized as an offset to the buffer base address
50 tensor_data += buffer_start;
51
52 GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow
53 GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);
54 }
55
56 result->op = (ggml_op) tensor->op;
57 for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
58 result->op_params[i] = tensor->op_params[i];
59 }
60 result->flags = tensor->flags;
61 result->data = reinterpret_cast<void *>(tensor_data);
62 ggml_set_name(result, tensor->name);
63 return result;
64}
65
66ggml_tensor * apir_create_node(uint64_t id,
67 ggml_context * ctx,
68 const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
69 std::unordered_map<uint64_t, ggml_tensor *> & tensor_map) {
70 if (id == 0) {
71 return nullptr;
72 }
73 if (tensor_map.find(id) != tensor_map.end()) {
74 return tensor_map[id];
75 }
76 const apir_rpc_tensor * tensor = tensor_ptrs.at(id);
77 ggml_tensor * result = apir_deserialize_tensor(ctx, tensor);
78 if (result == nullptr) {
79 return nullptr;
80 }
81 tensor_map[id] = result;
82 for (int i = 0; i < GGML_MAX_SRC; i++) {
83 result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
84 }
85 result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
86 result->view_offs = tensor->view_offs;
87 return result;
88}
89
90ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes,
91 uint32_t n_tensors,
92 const apir_rpc_tensor * tensors,
93 const uint64_t * nodes) {
94 size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
95 ggml_init_params params = {
96 /*.mem_size =*/buf_size,
97 /*.mem_buffer =*/NULL,
98 /*.no_alloc =*/true,
99 };
100 ggml_context * ctx = ggml_init(params);
101 ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
102 graph->n_nodes = n_nodes;
103 std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;
104 for (uint32_t i = 0; i < n_tensors; i++) {
105 tensor_ptrs[tensors[i].id] = &tensors[i];
106 }
107 std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
108 for (uint32_t i = 0; i < n_nodes; i++) {
109 int64_t id;
110 memcpy(&id, &nodes[i], sizeof(id));
111 graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);
112 }
113
114 return graph;
115}