1#include "backend-dispatched.h"
2#include "backend-virgl-apir.h"
3#include "ggml-backend-impl.h"
4#include "ggml-backend.h"
5#include "ggml-impl.h"
6
7#include <cstdint>
8
9uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
10 GGML_UNUSED(ctx);
11 GGML_UNUSED(ctx);
12 GGML_UNUSED(dec);
13
14 int32_t dev_count = reg->iface.get_device_count(reg);
15 apir_encode_int32_t(enc, &dev_count);
16
17 return 0;
18}
19
20uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
21 GGML_UNUSED(ctx);
22 GGML_UNUSED(ctx);
23 GGML_UNUSED(dec);
24
25 int32_t dev_count = reg->iface.get_device_count(reg);
26 apir_encode_int32_t(enc, &dev_count);
27
28 return 0;
29}
30
31uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
32 GGML_UNUSED(ctx);
33 GGML_UNUSED(dec);
34
35 const char * string = dev->iface.get_name(dev);
36
37 const size_t string_size = strlen(string) + 1;
38 apir_encode_array_size(enc, string_size);
39 apir_encode_char_array(enc, string, string_size);
40
41 return 0;
42}
43
44uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
45 GGML_UNUSED(ctx);
46 GGML_UNUSED(dec);
47
48 const char * string = dev->iface.get_description(dev);
49
50 const size_t string_size = strlen(string) + 1;
51 apir_encode_array_size(enc, string_size);
52 apir_encode_char_array(enc, string, string_size);
53
54 return 0;
55}
56
57uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
58 GGML_UNUSED(ctx);
59 GGML_UNUSED(dec);
60
61 uint32_t type = dev->iface.get_type(dev);
62 apir_encode_uint32_t(enc, &type);
63
64 return 0;
65}
66
67uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
68 GGML_UNUSED(ctx);
69 GGML_UNUSED(dec);
70
71 size_t free, total;
72 dev->iface.get_memory(dev, &free, &total);
73
74 apir_encode_size_t(enc, &free);
75 apir_encode_size_t(enc, &total);
76
77 return 0;
78}
79
80uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
81 GGML_UNUSED(ctx);
82
83 const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
84
85 bool supports_op = dev->iface.supports_op(dev, op);
86
87 apir_encode_bool_t(enc, &supports_op);
88
89 return 0;
90}
91
92uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
93 GGML_UNUSED(ctx);
94 GGML_UNUSED(dec);
95
96 ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev);
97
98 apir_encode_ggml_buffer_type(enc, bufft);
99
100 return 0;
101}
102
103uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
104 GGML_UNUSED(ctx);
105 GGML_UNUSED(dec);
106
107 ggml_backend_dev_props props;
108 dev->iface.get_props(dev, &props);
109
110 apir_encode_bool_t(enc, &props.caps.async);
111 apir_encode_bool_t(enc, &props.caps.host_buffer);
112 apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr);
113 apir_encode_bool_t(enc, &props.caps.events);
114
115 return 0;
116}
117
118uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
119 GGML_UNUSED(ctx);
120 GGML_UNUSED(dec);
121
122 uint32_t shmem_res_id;
123 apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
124
125 void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
126 if (!shmem_ptr) {
127 GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
128 apir_decoder_set_fatal(dec);
129 return 1;
130 }
131
132 size_t size;
133 apir_decode_size_t(dec, &size);
134 size_t max_tensor_size;
135 apir_decode_size_t(dec, &max_tensor_size);
136
137 ggml_backend_buffer_t buffer;
138 buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size);
139
140 apir_encode_ggml_buffer(enc, buffer);
141 apir_encode_ggml_buffer_type(enc, buffer->buft);
142
143 if (buffer) {
144 apir_track_backend_buffer(buffer);
145 }
146
147 return 0;
148}