1#include "ggml-backend-impl.h"
  2#include "ggml-impl.h"
  3#include "shared/apir_cs_rpc.h"
  4
  5#include <cinttypes>
  6#include <unordered_map>
  7#include <unordered_set>
  8#include <vector>
  9
 10std::unordered_set<ggml_backend_buffer_t> backend_buffers;
 11
 12void apir_track_backend_buffer(ggml_backend_buffer_t buffer) {
 13    backend_buffers.insert(buffer);
 14}
 15
 16bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {
 17    auto it = backend_buffers.find(buffer);
 18    if (it == backend_buffers.end()) {
 19        return false;
 20    }
 21
 22    backend_buffers.erase(it);
 23    return true;
 24}
 25
 26std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {
 27    return backend_buffers;
 28}
 29
 30ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {
 31    ggml_tensor * result =
 32        ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
 33    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
 34        result->nb[i] = tensor->nb[i];
 35    }
 36    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
 37    if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {
 38        printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer);
 39        result->buffer = nullptr;
 40    }
 41
 42    uint64_t tensor_data = tensor->data;
 43    if (result->buffer) {
 44        // require that the tensor data does not go beyond the buffer end
 45        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
 46        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
 47        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
 48
 49        // tensor->data is serialized as an offset to the buffer base address
 50        tensor_data += buffer_start;
 51
 52        GGML_ASSERT(tensor_data + tensor_size >= tensor_data);  // check for overflow
 53        GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);
 54    }
 55
 56    result->op = (ggml_op) tensor->op;
 57    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
 58        result->op_params[i] = tensor->op_params[i];
 59    }
 60    result->flags = tensor->flags;
 61    result->data  = reinterpret_cast<void *>(tensor_data);
 62    ggml_set_name(result, tensor->name);
 63    return result;
 64}
 65
 66ggml_tensor * apir_create_node(uint64_t                                                      id,
 67                               ggml_context *                                                ctx,
 68                               const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
 69                               std::unordered_map<uint64_t, ggml_tensor *> &                 tensor_map) {
 70    if (id == 0) {
 71        return nullptr;
 72    }
 73    if (tensor_map.find(id) != tensor_map.end()) {
 74        return tensor_map[id];
 75    }
 76    const apir_rpc_tensor * tensor = tensor_ptrs.at(id);
 77    ggml_tensor *           result = apir_deserialize_tensor(ctx, tensor);
 78    if (result == nullptr) {
 79        return nullptr;
 80    }
 81    tensor_map[id] = result;
 82    for (int i = 0; i < GGML_MAX_SRC; i++) {
 83        result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
 84    }
 85    result->view_src  = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
 86    result->view_offs = tensor->view_offs;
 87    return result;
 88}
 89
 90ggml_cgraph * apir_deserialize_graph(uint32_t                n_nodes,
 91                                     uint32_t                n_tensors,
 92                                     const apir_rpc_tensor * tensors,
 93                                     const uint64_t *        nodes) {
 94    size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
 95    ggml_init_params params = {
 96        /*.mem_size   =*/buf_size,
 97        /*.mem_buffer =*/NULL,
 98        /*.no_alloc   =*/true,
 99    };
100    ggml_context * ctx   = ggml_init(params);
101    ggml_cgraph *  graph = ggml_new_graph_custom(ctx, n_nodes, false);
102    graph->n_nodes       = n_nodes;
103    std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;
104    for (uint32_t i = 0; i < n_tensors; i++) {
105        tensor_ptrs[tensors[i].id] = &tensors[i];
106    }
107    std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
108    for (uint32_t i = 0; i < n_nodes; i++) {
109        int64_t id;
110        memcpy(&id, &nodes[i], sizeof(id));
111        graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);
112    }
113
114    return graph;
115}