1#include "ggml-remoting.h"
 2
 3static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
 4                                                                            size_t                     size) {
 5    virtgpu * gpu = BUFT_TO_GPU(buft);
 6
 7    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
 8    if (!context) {
 9        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__);
10    }
11
12    context->gpu = gpu;
13
14    bool async__unused, host_buffer__unused, events__unused;
15    bool buffer_from_host_ptr;
16    apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused);
17
18    if (buffer_from_host_ptr) {
19        context->apir_context = apir_device_buffer_from_ptr(gpu, size, size);
20        context->base         = context->apir_context.shmem.mmap_ptr;
21        context->is_from_ptr  = true;
22    } else {
23        context->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size);
24        context->is_from_ptr  = false;
25        context->base         = NULL;
26    }
27
28    ggml_backend_buffer_t buffer =
29        ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size);
30
31    return buffer;
32}
33
34static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
35    virtgpu * gpu = BUFT_TO_GPU(buft);
36
37    return gpu->cached_buffer_type.name;
38}
39
40static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
41    virtgpu * gpu = BUFT_TO_GPU(buft);
42
43    return gpu->cached_buffer_type.alignment;
44}
45
46static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
47    virtgpu * gpu = BUFT_TO_GPU(buft);
48
49    return gpu->cached_buffer_type.max_size;
50}
51
52static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
53                                                               const ggml_tensor *        tensor) {
54    virtgpu * gpu = BUFT_TO_GPU(buft);
55
56    if (tensor->buffer == NULL
57        || !tensor->buffer->context
58        || !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
59        return ggml_nbytes(tensor);
60    }
61
62    return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor);
63}
64
65const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = {
66    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,
67    /* .alloc_buffer     = */ ggml_backend_remoting_buffer_type_alloc_buffer,
68    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,
69    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,
70    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,
71    /* .is_host          = */ NULL,
72};
73
74const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = {
75    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,
76    /* .alloc_buffer     = */ NULL,
77    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,
78    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,
79    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,
80    /* .is_host          = */ NULL,
81};