1#pragma once
2
3#include "ggml-impl.h"
4
5#include <cassert>
6#include <cstring>
7
/*
 * Branch-prediction hints (GCC/Clang __builtin_expect).
 * Guarded so that if another header already provides these macros, its
 * definition is kept instead of producing a macro-redefinition diagnostic.
 */
#ifndef likely
#define likely(x) __builtin_expect(!!(x), 1)
#endif
#ifndef unlikely
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
10
/*
 * Write cursor over a caller-provided byte buffer.
 * Field order matters: apir_new_encoder() initializes it with designated
 * initializers in declaration order.
 */
struct apir_encoder {
    char *cur;               /* next byte to be written */
    const char *start, *end; /* buffer bounds: [start, end) */
    bool fatal;              /* sticky error flag, see apir_encoder_set_fatal()/get_fatal() */
};
18
/*
 * Read cursor over a caller-provided byte buffer.
 * Field order matters: apir_new_decoder() initializes it with designated
 * initializers in declaration order.
 */
struct apir_decoder {
    const char *cur, *end; /* unread bytes: [cur, end) */
    bool fatal;            /* sticky error flag, see apir_decoder_set_fatal()/get_fatal() */
};
24
25/*
26 * new encoder and decoder
27 */
28
29static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
30 apir_decoder dec = {
31 .cur = ptr,
32 .end = ptr + size,
33 .fatal = false,
34 };
35
36 return dec;
37}
38
39static apir_encoder apir_new_encoder(char * ptr, size_t size) {
40 apir_encoder enc = {
41 .cur = ptr,
42 .start = ptr,
43 .end = ptr + size,
44 .fatal = false,
45 };
46
47 return enc;
48}
49
50/*
51 * fatal flag handling
52 */
53
54static inline void apir_encoder_reset_fatal(apir_encoder * enc) {
55 enc->fatal = false;
56}
57
58static inline void apir_encoder_set_fatal(apir_encoder * enc) {
59 enc->fatal = true;
60}
61
62static inline bool apir_encoder_get_fatal(const apir_encoder * enc) {
63 return enc->fatal;
64}
65
66static inline void apir_decoder_reset_fatal(apir_decoder * dec) {
67 dec->fatal = false;
68}
69
70static inline void apir_decoder_set_fatal(apir_decoder * dec) {
71 dec->fatal = true;
72}
73
74static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
75 return dec->fatal;
76}
77
78/*
 * decode peek
80 */
81
82static inline bool apir_decoder_peek_internal(apir_decoder * dec,
83 size_t size,
84 void * val,
85 size_t val_size) {
86 assert(val_size <= size);
87
88 if (unlikely(size > (size_t) (dec->end - dec->cur))) {
89 GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
90 apir_decoder_set_fatal(dec);
91 memset(val, 0, val_size);
92 return false;
93 }
94
95 /* we should not rely on the compiler to optimize away memcpy... */
96 memcpy(val, dec->cur, val_size);
97 return true;
98}
99
100static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {
101 apir_decoder_peek_internal(dec, size, val, val_size);
102}
103
104static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {
105 if (unlikely(size > (size_t) (dec->end - dec->cur))) {
106 GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
107 apir_decoder_set_fatal(dec);
108 return NULL;
109 }
110 const void * addr = dec->cur;
111 dec->cur += size;
112
113 return addr;
114}
115
116/*
117 * read/write
118 */
119
120static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {
121 if (apir_decoder_peek_internal(dec, size, val, val_size)) {
122 dec->cur += size;
123 }
124}
125
126static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {
127 assert(val_size <= size);
128 assert(size <= ((size_t) (enc->end - enc->cur)));
129
130 char * write_addr = enc->cur;
131 /* we should not rely on the compiler to optimize away memcpy... */
132 memcpy(write_addr, val, val_size);
133 enc->cur += size;
134
135 return write_addr;
136}
137
138/*
139 * encode/decode
140 */
141
142static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {
143 assert(size % 4 == 0);
144 apir_decoder_read(dec, size, data, data_size);
145}
146
147static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {
148 assert(size % 4 == 0);
149 apir_encoder_write(enc, size, data, data_size);
150}
151
152/*
153 * typed encode/decode
154 */
155
156/* uint8_t */
157
158static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {
159 apir_encode(enc, sizeof(int), val, sizeof(*val));
160}
161
162static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {
163 apir_decode(dec, sizeof(int), val, sizeof(*val));
164}
165
166/* uint64_t */
167
168static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {
169 apir_encode(enc, 8, val, sizeof(*val));
170}
171
172static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {
173 apir_decode(dec, 8, val, sizeof(*val));
174}
175
176static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {
177 const size_t size = sizeof(*val) * count;
178 assert(size >= count);
179 apir_encode(enc, size, val, size);
180}
181
182static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {
183 const size_t size = sizeof(*val) * count;
184 assert(size >= count);
185 apir_decode(dec, size, val, size);
186}
187
188static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {
189 return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));
190}
191
192/* int32_t */
193
194static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {
195 apir_encode(enc, 4, val, sizeof(*val));
196}
197
198static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {
199 apir_decode(dec, 4, val, sizeof(*val));
200}
201
202static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {
203 const size_t size = sizeof(*val) * count;
204 assert(size >= count);
205 apir_encode(enc, size, val, size);
206}
207
208static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {
209 const size_t size = sizeof(*val) * count;
210 assert(size >= count);
211 apir_decode(dec, size, val, size);
212}
213
214/* array size (uint64_t) */
215
216static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {
217 apir_encode_uint64_t(enc, &size);
218}
219
220static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {
221 uint64_t size;
222 apir_decode_uint64_t(dec, &size);
223 if (size != expected_size) {
224 GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__);
225 apir_decoder_set_fatal(dec);
226 size = 0;
227 }
228 return size;
229}
230
231static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {
232 uint64_t size;
233 apir_decode_uint64_t(dec, &size);
234 return size;
235}
236
237/* non-array pointer */
238
239static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {
240 apir_encode_array_size(enc, val ? 1 : 0);
241 return val;
242}
243
244static inline bool apir_decode_simple_pointer(apir_decoder * dec) {
245 return apir_decode_array_size_unchecked(dec);
246}
247
248/* uint32_t */
249
250static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {
251 apir_encode(enc, 4, val, sizeof(*val));
252}
253
254static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {
255 apir_decode(dec, 4, val, sizeof(*val));
256}
257
258static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {
259 const size_t size = sizeof(*val) * count;
260 assert(size >= count);
261 apir_encode(enc, size, val, size);
262}
263
264static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {
265 const size_t size = sizeof(*val) * count;
266 assert(size >= count);
267 apir_decode(dec, size, val, size);
268}
269
270/* size_t */
271
272static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {
273 const uint64_t tmp = *val;
274 apir_encode_uint64_t(enc, &tmp);
275}
276
277static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {
278 uint64_t tmp;
279 apir_decode_uint64_t(dec, &tmp);
280 *val = tmp;
281}
282
283static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {
284 if (sizeof(size_t) == sizeof(uint64_t)) {
285 apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);
286 } else {
287 for (uint32_t i = 0; i < count; i++) {
288 apir_encode_size_t(enc, &val[i]);
289 }
290 }
291}
292
293static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {
294 if (sizeof(size_t) == sizeof(uint64_t)) {
295 apir_decode_uint64_t_array(dec, (uint64_t *) val, count);
296 } else {
297 for (uint32_t i = 0; i < count; i++) {
298 apir_decode_size_t(dec, &val[i]);
299 }
300 }
301}
302
303/* opaque blob */
304
305static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) {
306 apir_encode(enc, (size + 3) & ~3, val, size);
307}
308
309static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {
310 apir_decode(dec, (size + 3) & ~3, val, size);
311}
312
313/* string */
314
315static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {
316 assert(size && strlen(val) < size);
317 apir_encode_blob_array(enc, val, size);
318}
319
320static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {
321 apir_decode_blob_array(dec, val, size);
322 if (size) {
323 val[size - 1] = '\0';
324 } else {
325 GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__);
326 apir_decoder_set_fatal(dec);
327 }
328}
329
330/* (temp) buffer allocation */
331
332static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
333 size_t alloc_size;
334 if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
335 GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n",
336 __func__, size, count);
337 return NULL;
338 }
339
340 return malloc(alloc_size);
341}
342
343/* bool */
344
345static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {
346 apir_encode(enc, sizeof(int), val, sizeof(bool));
347}
348
349static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
350 apir_decode(dec, sizeof(int), val, sizeof(bool));
351}
352
353/* apir_buffer_type_host_handle_t */
354
355static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
356 const apir_buffer_type_host_handle_t * val) {
357 apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
358}
359
360static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
361 apir_buffer_type_host_handle_t * val) {
362 apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
363}
364
365/* apir_buffer_host_handle_t */
366
367static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
368 const apir_buffer_host_handle_t * val) {
369 apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
370}
371
372static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {
373 apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
374}
375
376/* uintptr_t */
377
378static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {
379 apir_encode(enc, sizeof(*val), val, sizeof(*val));
380}
381
382static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {
383 apir_decode(dec, sizeof(*val), val, sizeof(*val));
384}