1#include "ggml-remoting.h"
  2
  3static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
  4    virtgpu * gpu = DEV_TO_GPU(dev);
  5
  6    return gpu->cached_device_info.name;
  7}
  8
  9static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) {
 10    virtgpu * gpu = DEV_TO_GPU(dev);
 11
 12    // Return the pre-cached description from the virtgpu structure
 13    return gpu->cached_device_info.description;
 14}
 15
 16static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) {
 17    virtgpu * gpu = DEV_TO_GPU(dev);
 18
 19    return (enum ggml_backend_dev_type) gpu->cached_device_info.type;
 20}
 21
 22static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
 23    virtgpu * gpu = DEV_TO_GPU(dev);
 24
 25    *free = gpu->cached_device_info.memory_free;
 26    *total = gpu->cached_device_info.memory_total;
 27}
 28
 29static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
 30#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1
 31    /* ggml-rpc cheats it like this */
 32    /* with the current implementation of serialize_tensor, the src/view aren't properly passed */
 33    UNUSED(dev);
 34    UNUSED(op);
 35
 36    return true;
 37#else
 38    virtgpu * gpu = DEV_TO_GPU(dev);
 39
 40    return apir_device_supports_op(gpu, op);
 41#endif
 42}
 43
 44static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
 45    bool supported = buft->device == dev;
 46
 47    return supported;
 48}
 49
 50static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
 51    UNUSED(dev);
 52    UNUSED(op);
 53
 54    return false;
 55}
 56
 57static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
 58    props->name        = ggml_backend_remoting_device_get_name(dev);
 59    props->description = ggml_backend_remoting_device_get_description(dev);
 60    props->type        = ggml_backend_remoting_device_get_type(dev);
 61    ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total);
 62
 63    virtgpu * gpu = DEV_TO_GPU(dev);
 64    apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr,
 65                          &props->caps.events);
 66
 67    props->caps.buffer_from_host_ptr = false;
 68    props->caps.async                = false;
 69    props->caps.events               = false;
 70}
 71
 72ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
 73    virtgpu * gpu = DEV_TO_GPU(dev);
 74
 75    static std::atomic<bool> initialized = false;
 76    static ggml_backend_buffer_type buft;
 77
 78    if (!initialized) {
 79        static std::mutex           mutex;
 80        std::lock_guard<std::mutex> lock(mutex);
 81
 82        if (!initialized) {
 83            buft = {
 84                /* .iface    = */ ggml_backend_remoting_buffer_type_interface,
 85                /* .device   = */ dev,
 86                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,
 87            };
 88            initialized = true;
 89        }
 90    }
 91
 92    return &buft;
 93}
 94
 95static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
 96    virtgpu * gpu = DEV_TO_GPU(dev);
 97
 98    static std::atomic<bool> initialized = false;
 99    static ggml_backend_buffer_type buft;
100
101    if (!initialized) {
102        static std::mutex           mutex;
103        std::lock_guard<std::mutex> lock(mutex);
104
105        if (!initialized) {
106            buft = {
107                /* .iface    = */ ggml_backend_remoting_buffer_from_ptr_type_interface,
108                /* .device   = */ dev,
109                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,
110            };
111            initialized = true;
112        }
113    }
114
115    return &buft;
116}
117
118static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev,
119                                                                          void *             ptr,
120                                                                          size_t             size,
121                                                                          size_t             max_tensor_size) {
122    virtgpu * gpu = DEV_TO_GPU(dev);
123
124    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
125    if (!context) {
126        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__);
127    }
128
129    context->gpu          = gpu;
130    context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size);
131    context->base         = ptr;
132    context->is_from_ptr  = true;
133
134    ggml_backend_buffer_t buffer =
135        ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev),
136                                 ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size);
137
138    return buffer;
139}
140
// Device vtable for the remoting (virtgpu) backend.
// Entries are positional: their order must match the member order of
// ggml_backend_device_i. The /* .name = */ comments only label the slots.
// NULL slots: no host buffer type is provided, and no event
// creation/free/synchronize callbacks are implemented.
const ggml_backend_device_i ggml_backend_remoting_device_interface = {
    /* .get_name             = */ ggml_backend_remoting_device_get_name,
    /* .get_description      = */ ggml_backend_remoting_device_get_description,
    /* .get_memory           = */ ggml_backend_remoting_device_get_memory,
    /* .get_type             = */ ggml_backend_remoting_device_get_type,
    /* .get_props            = */ ggml_backend_remoting_device_get_props,
    /* .init_backend         = */ ggml_backend_remoting_device_init,
    /* .get_buffer_type      = */ ggml_backend_remoting_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr,
    /* .supports_op          = */ ggml_backend_remoting_device_supports_op,
    /* .supports_buft        = */ ggml_backend_remoting_device_supports_buft,
    /* .offload_op           = */ ggml_backend_remoting_device_offload_op,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};