1#pragma once
  2
#include "ggml-impl.h"

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
  7
  8#define likely(x)   __builtin_expect(!!(x), 1)
  9#define unlikely(x) __builtin_expect(!!(x), 0)
 10
/*
 * Write cursor over a caller-provided output buffer.
 * apir_new_encoder() initialises it from (ptr, size); the apir_encoder_* /
 * apir_encode_* helpers advance `cur` as values are appended.
 */
struct apir_encoder {
    char *       cur;    /* next byte to be written */
    const char * start;  /* beginning of the buffer (NOTE(review): not read by the helpers in this header) */
    const char * end;    /* one past the last usable byte */
    bool         fatal;  /* error flag; see apir_encoder_set_fatal() / apir_encoder_get_fatal() */

};
 18
/*
 * Read cursor over an input buffer [cur, end).
 * Reads advance `cur`; an attempt to read past `end` sets `fatal` instead
 * (see apir_decoder_peek_internal() and apir_decoder_use_inplace()).
 */
struct apir_decoder {
    const char * cur;    /* next byte to be read */
    const char * end;    /* one past the last readable byte */
    bool         fatal;  /* set when a decode error has occurred */
};
 24
 25/*
 26 * new encoder and decoder
 27 */
 28
 29static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
 30    apir_decoder dec = {
 31        .cur = ptr,
 32        .end = ptr + size,
 33        .fatal = false,
 34    };
 35
 36    return dec;
 37}
 38
 39static apir_encoder apir_new_encoder(char * ptr, size_t size) {
 40    apir_encoder enc = {
 41        .cur   = ptr,
 42        .start = ptr,
 43        .end   = ptr + size,
 44        .fatal = false,
 45    };
 46
 47    return enc;
 48}
 49
 50/*
 51 * fatal flag handling
 52 */
 53
 54static inline void apir_encoder_reset_fatal(apir_encoder * enc) {
 55    enc->fatal = false;
 56}
 57
 58static inline void apir_encoder_set_fatal(apir_encoder * enc) {
 59    enc->fatal = true;
 60}
 61
 62static inline bool apir_encoder_get_fatal(const apir_encoder * enc) {
 63    return enc->fatal;
 64}
 65
 66static inline void apir_decoder_reset_fatal(apir_decoder * dec) {
 67    dec->fatal = false;
 68}
 69
 70static inline void apir_decoder_set_fatal(apir_decoder * dec) {
 71    dec->fatal = true;
 72}
 73
 74static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
 75    return dec->fatal;
 76}
 77
 78/*
 * decode peek
 80 */
 81
 82static inline bool apir_decoder_peek_internal(apir_decoder * dec,
 83                                              size_t                size,
 84                                              void *                val,
 85                                              size_t                val_size) {
 86    assert(val_size <= size);
 87
 88    if (unlikely(size > (size_t) (dec->end - dec->cur))) {
 89        GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
 90        apir_decoder_set_fatal(dec);
 91        memset(val, 0, val_size);
 92        return false;
 93    }
 94
 95    /* we should not rely on the compiler to optimize away memcpy... */
 96    memcpy(val, dec->cur, val_size);
 97    return true;
 98}
 99
100static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {
101    apir_decoder_peek_internal(dec, size, val, val_size);
102}
103
104static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {
105    if (unlikely(size > (size_t) (dec->end - dec->cur))) {
106        GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
107        apir_decoder_set_fatal(dec);
108        return NULL;
109    }
110    const void * addr = dec->cur;
111    dec->cur += size;
112
113    return addr;
114}
115
116/*
117 * read/write
118 */
119
120static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {
121    if (apir_decoder_peek_internal(dec, size, val, val_size)) {
122        dec->cur += size;
123    }
124}
125
126static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {
127    assert(val_size <= size);
128    assert(size <= ((size_t) (enc->end - enc->cur)));
129
130    char * write_addr = enc->cur;
131    /* we should not rely on the compiler to optimize away memcpy... */
132    memcpy(write_addr, val, val_size);
133    enc->cur += size;
134
135    return write_addr;
136}
137
138/*
139 * encode/decode
140 */
141
142static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {
143    assert(size % 4 == 0);
144    apir_decoder_read(dec, size, data, data_size);
145}
146
147static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {
148    assert(size % 4 == 0);
149    apir_encoder_write(enc, size, data, data_size);
150}
151
152/*
153 * typed encode/decode
154 */
155
156/* uint8_t */
157
158static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {
159    apir_encode(enc, sizeof(int), val, sizeof(*val));
160}
161
162static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {
163    apir_decode(dec, sizeof(int), val, sizeof(*val));
164}
165
166/* uint64_t */
167
168static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {
169    apir_encode(enc, 8, val, sizeof(*val));
170}
171
172static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {
173    apir_decode(dec, 8, val, sizeof(*val));
174}
175
176static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {
177    const size_t size = sizeof(*val) * count;
178    assert(size >= count);
179    apir_encode(enc, size, val, size);
180}
181
182static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {
183    const size_t size = sizeof(*val) * count;
184    assert(size >= count);
185    apir_decode(dec, size, val, size);
186}
187
188static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {
189    return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));
190}
191
192/* int32_t */
193
194static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {
195    apir_encode(enc, 4, val, sizeof(*val));
196}
197
198static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {
199    apir_decode(dec, 4, val, sizeof(*val));
200}
201
202static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {
203    const size_t size = sizeof(*val) * count;
204    assert(size >= count);
205    apir_encode(enc, size, val, size);
206}
207
208static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {
209    const size_t size = sizeof(*val) * count;
210    assert(size >= count);
211    apir_decode(dec, size, val, size);
212}
213
214/* array size (uint64_t) */
215
216static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {
217    apir_encode_uint64_t(enc, &size);
218}
219
220static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {
221    uint64_t size;
222    apir_decode_uint64_t(dec, &size);
223    if (size != expected_size) {
224        GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__);
225        apir_decoder_set_fatal(dec);
226        size = 0;
227    }
228    return size;
229}
230
231static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {
232    uint64_t size;
233    apir_decode_uint64_t(dec, &size);
234    return size;
235}
236
237/* non-array pointer */
238
239static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {
240    apir_encode_array_size(enc, val ? 1 : 0);
241    return val;
242}
243
244static inline bool apir_decode_simple_pointer(apir_decoder * dec) {
245    return apir_decode_array_size_unchecked(dec);
246}
247
248/* uint32_t */
249
250static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {
251    apir_encode(enc, 4, val, sizeof(*val));
252}
253
254static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {
255    apir_decode(dec, 4, val, sizeof(*val));
256}
257
258static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {
259    const size_t size = sizeof(*val) * count;
260    assert(size >= count);
261    apir_encode(enc, size, val, size);
262}
263
264static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {
265    const size_t size = sizeof(*val) * count;
266    assert(size >= count);
267    apir_decode(dec, size, val, size);
268}
269
270/* size_t */
271
272static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {
273    const uint64_t tmp = *val;
274    apir_encode_uint64_t(enc, &tmp);
275}
276
277static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {
278    uint64_t tmp;
279    apir_decode_uint64_t(dec, &tmp);
280    *val = tmp;
281}
282
283static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {
284    if (sizeof(size_t) == sizeof(uint64_t)) {
285        apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);
286    } else {
287        for (uint32_t i = 0; i < count; i++) {
288            apir_encode_size_t(enc, &val[i]);
289        }
290    }
291}
292
293static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {
294    if (sizeof(size_t) == sizeof(uint64_t)) {
295        apir_decode_uint64_t_array(dec, (uint64_t *) val, count);
296    } else {
297        for (uint32_t i = 0; i < count; i++) {
298            apir_decode_size_t(dec, &val[i]);
299        }
300    }
301}
302
303/* opaque blob */
304
305static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) {
306    apir_encode(enc, (size + 3) & ~3, val, size);
307}
308
309static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {
310    apir_decode(dec, (size + 3) & ~3, val, size);
311}
312
313/* string */
314
315static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {
316    assert(size && strlen(val) < size);
317    apir_encode_blob_array(enc, val, size);
318}
319
320static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {
321    apir_decode_blob_array(dec, val, size);
322    if (size) {
323        val[size - 1] = '\0';
324    } else {
325        GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__);
326        apir_decoder_set_fatal(dec);
327    }
328}
329
330/* (temp) buffer allocation */
331
332static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
333    size_t alloc_size;
334    if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
335        GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n",
336                       __func__, size, count);
337        return NULL;
338    }
339
340    return malloc(alloc_size);
341}
342
343/* bool */
344
345static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {
346    apir_encode(enc, sizeof(int), val, sizeof(bool));
347}
348
349static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
350    apir_decode(dec, sizeof(int), val, sizeof(bool));
351}
352
353/* apir_buffer_type_host_handle_t */
354
355static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder *                  enc,
356                                                              const apir_buffer_type_host_handle_t * val) {
357    apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
358}
359
360static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder *            dec,
361                                                              apir_buffer_type_host_handle_t * val) {
362    apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
363}
364
365/* apir_buffer_host_handle_t */
366
367static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder *             enc,
368                                                         const apir_buffer_host_handle_t * val) {
369    apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
370}
371
372static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {
373    apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
374}
375
376/* uintptr_t */
377
378static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {
379    apir_encode(enc, sizeof(*val), val, sizeof(*val));
380}
381
382static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {
383    apir_decode(dec, sizeof(*val), val, sizeof(*val));
384}