#include "ggml-remoting.h"
#include "ggml-virtgpu.h"

#include <atomic>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>
  6
// Prototype of the public cleanup routine defined at the end of this file.
// NOTE(review): presumably declared here because no included header provides it
// (e.g. to silence missing-declaration warnings) — confirm against ggml-virtgpu.h.
void ggml_virtgpu_cleanup(virtgpu * gpu);
  8
  9static virtgpu * apir_initialize() {
 10    static virtgpu *         gpu          = NULL;
 11    static std::atomic<bool> initialized  = false;
 12
 13    if (initialized) {
 14        // fast track
 15        return gpu;
 16    }
 17
 18    {
 19        static std::mutex           mutex;
 20        std::lock_guard<std::mutex> lock(mutex);
 21
 22        if (initialized) {
 23            // thread safe
 24            return gpu;
 25        }
 26
 27        gpu = create_virtgpu();
 28        if (!gpu) {
 29            initialized = true;
 30            return NULL;
 31        }
 32
 33        // Pre-fetch and cache all device information, it will not change
 34        gpu->cached_device_info.description  = apir_device_get_description(gpu);
 35        if (!gpu->cached_device_info.description) {
 36            GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
 37        }
 38        gpu->cached_device_info.name         = apir_device_get_name(gpu);
 39        if (!gpu->cached_device_info.name) {
 40            GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__);
 41        }
 42        gpu->cached_device_info.device_count = apir_device_get_count(gpu);
 43        gpu->cached_device_info.type         = apir_device_get_type(gpu);
 44
 45        apir_device_get_memory(gpu,
 46                              &gpu->cached_device_info.memory_free,
 47                              &gpu->cached_device_info.memory_total);
 48
 49        apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
 50        gpu->cached_buffer_type.host_handle             = buft_host_handle;
 51        gpu->cached_buffer_type.name                    = apir_buffer_type_get_name(gpu, buft_host_handle);
 52        if (!gpu->cached_buffer_type.name) {
 53            GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__);
 54        }
 55        gpu->cached_buffer_type.alignment               = apir_buffer_type_get_alignment(gpu, buft_host_handle);
 56        gpu->cached_buffer_type.max_size                = apir_buffer_type_get_max_size(gpu, buft_host_handle);
 57
 58        initialized = true;
 59    }
 60
 61    return gpu;
 62}
 63
 64static int ggml_backend_remoting_get_device_count() {
 65    virtgpu * gpu = apir_initialize();
 66    if (!gpu) {
 67        return 0;
 68    }
 69
 70    return gpu->cached_device_info.device_count;
 71}
 72
 73static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) {
 74    UNUSED(reg);
 75
 76    return ggml_backend_remoting_get_device_count();
 77}
 78
 79static std::vector<ggml_backend_dev_t> devices;
 80
 81ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) {
 82    GGML_ASSERT(device < devices.size());
 83    return devices[device];
 84}
 85
 86static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
 87    if (devices.size() > 0) {
 88        GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__);
 89        return;
 90    }
 91
 92    virtgpu * gpu = apir_initialize();
 93    if (!gpu) {
 94        GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__);
 95        return;
 96    }
 97
 98    static std::atomic<bool> initialized = false;
 99
100    if (initialized) {
101        return; // fast track
102    }
103
104    {
105        static std::mutex           mutex;
106        std::lock_guard<std::mutex> lock(mutex);
107        if (!initialized) {
108            for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) {
109                ggml_backend_remoting_device_context * ctx       = new ggml_backend_remoting_device_context;
110                char                                   desc[256] = "ggml-virtgpu API Remoting device";
111
112                ctx->device      = i;
113                ctx->name        = GGML_VIRTGPU_NAME + std::to_string(i);
114                ctx->description = desc;
115                ctx->gpu         = gpu;
116
117                ggml_backend_dev_t dev = new ggml_backend_device{
118                    /* .iface   = */ ggml_backend_remoting_device_interface,
119                    /* .reg     = */ reg,
120                    /* .context = */ ctx,
121                };
122                devices.push_back(dev);
123            }
124            initialized = true;
125        }
126    }
127}
128
129static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) {
130    UNUSED(reg);
131
132    return ggml_backend_remoting_get_device(device);
133}
134
135static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) {
136    UNUSED(reg);
137
138    return GGML_VIRTGPU_NAME;
139}
140
// Registry vtable installed into the registry returned by ggml_backend_virtgpu_reg().
// Entries are positional and must follow the member order of ggml_backend_reg_i.
// get_proc_address is left NULL (presumably: no backend-specific extension
// functions are exposed — confirm if any are added).
static const ggml_backend_reg_i ggml_backend_remoting_reg_i = {
    /* .get_name         = */ ggml_backend_remoting_reg_get_name,
    /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count,
    /* .get_device       = */ ggml_backend_remoting_reg_get_device,
    /* .get_proc_address = */ NULL,
};
147
148ggml_backend_reg_t ggml_backend_virtgpu_reg() {
149    virtgpu * gpu = apir_initialize();
150    if (!gpu) {
151        GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_apir_initialize failed\n", __func__);
152    }
153
154    static ggml_backend_reg reg = {
155        /* .api_version = */ GGML_BACKEND_API_VERSION,
156        /* .iface       = */ ggml_backend_remoting_reg_i,
157        /* .context     = */ gpu,
158    };
159
160    static bool initialized = false;
161    if (initialized) {
162        return &reg;
163    }
164    initialized = true;
165
166    ggml_backend_remoting_reg_init_devices(&reg);
167
168    return &reg;
169}
170
171// public function, not exposed in the GGML interface at the moment
172void ggml_virtgpu_cleanup(virtgpu * gpu) {
173    if (gpu->cached_device_info.name) {
174        free(gpu->cached_device_info.name);
175        gpu->cached_device_info.name = NULL;
176    }
177    if (gpu->cached_device_info.description) {
178        free(gpu->cached_device_info.description);
179        gpu->cached_device_info.description = NULL;
180    }
181    if (gpu->cached_buffer_type.name) {
182        free(gpu->cached_buffer_type.name);
183        gpu->cached_buffer_type.name = NULL;
184    }
185
186    mtx_destroy(&gpu->data_shmem_mutex);
187}
188
// Dynamic-loading hook for this backend's registry function.
// NOTE(review): presumably expands to the exported entry point that ggml uses
// when backends are loaded as shared libraries (see GGML_BACKEND_DL_IMPL in
// ggml-backend-impl.h) — confirm.
GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg)