#include "ggml-remoting.h"
#include "ggml-virtgpu.h"

#include <atomic>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>
6
// Forward declaration; the function is defined further down in this file.
void ggml_virtgpu_cleanup(virtgpu * gpu);
8
9static virtgpu * apir_initialize() {
10 static virtgpu * gpu = NULL;
11 static std::atomic<bool> initialized = false;
12
13 if (initialized) {
14 // fast track
15 return gpu;
16 }
17
18 {
19 static std::mutex mutex;
20 std::lock_guard<std::mutex> lock(mutex);
21
22 if (initialized) {
23 // thread safe
24 return gpu;
25 }
26
27 gpu = create_virtgpu();
28 if (!gpu) {
29 initialized = true;
30 return NULL;
31 }
32
33 // Pre-fetch and cache all device information, it will not change
34 gpu->cached_device_info.description = apir_device_get_description(gpu);
35 if (!gpu->cached_device_info.description) {
36 GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
37 }
38 gpu->cached_device_info.name = apir_device_get_name(gpu);
39 if (!gpu->cached_device_info.name) {
40 GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__);
41 }
42 gpu->cached_device_info.device_count = apir_device_get_count(gpu);
43 gpu->cached_device_info.type = apir_device_get_type(gpu);
44
45 apir_device_get_memory(gpu,
46 &gpu->cached_device_info.memory_free,
47 &gpu->cached_device_info.memory_total);
48
49 apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
50 gpu->cached_buffer_type.host_handle = buft_host_handle;
51 gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle);
52 if (!gpu->cached_buffer_type.name) {
53 GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__);
54 }
55 gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
56 gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
57
58 initialized = true;
59 }
60
61 return gpu;
62}
63
64static int ggml_backend_remoting_get_device_count() {
65 virtgpu * gpu = apir_initialize();
66 if (!gpu) {
67 return 0;
68 }
69
70 return gpu->cached_device_info.device_count;
71}
72
73static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) {
74 UNUSED(reg);
75
76 return ggml_backend_remoting_get_device_count();
77}
78
79static std::vector<ggml_backend_dev_t> devices;
80
81ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) {
82 GGML_ASSERT(device < devices.size());
83 return devices[device];
84}
85
86static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
87 if (devices.size() > 0) {
88 GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__);
89 return;
90 }
91
92 virtgpu * gpu = apir_initialize();
93 if (!gpu) {
94 GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__);
95 return;
96 }
97
98 static std::atomic<bool> initialized = false;
99
100 if (initialized) {
101 return; // fast track
102 }
103
104 {
105 static std::mutex mutex;
106 std::lock_guard<std::mutex> lock(mutex);
107 if (!initialized) {
108 for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) {
109 ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context;
110 char desc[256] = "ggml-virtgpu API Remoting device";
111
112 ctx->device = i;
113 ctx->name = GGML_VIRTGPU_NAME + std::to_string(i);
114 ctx->description = desc;
115 ctx->gpu = gpu;
116
117 ggml_backend_dev_t dev = new ggml_backend_device{
118 /* .iface = */ ggml_backend_remoting_device_interface,
119 /* .reg = */ reg,
120 /* .context = */ ctx,
121 };
122 devices.push_back(dev);
123 }
124 initialized = true;
125 }
126 }
127}
128
129static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) {
130 UNUSED(reg);
131
132 return ggml_backend_remoting_get_device(device);
133}
134
135static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) {
136 UNUSED(reg);
137
138 return GGML_VIRTGPU_NAME;
139}
140
141static const ggml_backend_reg_i ggml_backend_remoting_reg_i = {
142 /* .get_name = */ ggml_backend_remoting_reg_get_name,
143 /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count,
144 /* .get_device = */ ggml_backend_remoting_reg_get_device,
145 /* .get_proc_address = */ NULL,
146};
147
148ggml_backend_reg_t ggml_backend_virtgpu_reg() {
149 virtgpu * gpu = apir_initialize();
150 if (!gpu) {
151 GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_apir_initialize failed\n", __func__);
152 }
153
154 static ggml_backend_reg reg = {
155 /* .api_version = */ GGML_BACKEND_API_VERSION,
156 /* .iface = */ ggml_backend_remoting_reg_i,
157 /* .context = */ gpu,
158 };
159
160 static bool initialized = false;
161 if (initialized) {
162 return ®
163 }
164 initialized = true;
165
166 ggml_backend_remoting_reg_init_devices(®);
167
168 return ®
169}
170
171// public function, not exposed in the GGML interface at the moment
172void ggml_virtgpu_cleanup(virtgpu * gpu) {
173 if (gpu->cached_device_info.name) {
174 free(gpu->cached_device_info.name);
175 gpu->cached_device_info.name = NULL;
176 }
177 if (gpu->cached_device_info.description) {
178 free(gpu->cached_device_info.description);
179 gpu->cached_device_info.description = NULL;
180 }
181 if (gpu->cached_buffer_type.name) {
182 free(gpu->cached_buffer_type.name);
183 gpu->cached_buffer_type.name = NULL;
184 }
185
186 mtx_destroy(&gpu->data_shmem_mutex);
187}
188
// Presumably emits the dynamic-loading entry point that exposes
// ggml_backend_virtgpu_reg to the backend loader; confirm against the
// GGML_BACKEND_DL_IMPL macro definition.
GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg)