1#include "backend-dispatched.h"
  2#include "backend-virgl-apir.h"
  3#include "ggml-backend-impl.h"
  4#include "ggml-backend.h"
  5#include "ggml-impl.h"
  6
  7#include <cstdint>
  8
  9uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
 10    GGML_UNUSED(ctx);
 11    ggml_backend_buffer_t buffer;
 12    buffer = apir_decode_ggml_buffer(dec);
 13
 14    uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
 15    apir_encode_uintptr_t(enc, &base);
 16
 17    return 0;
 18}
 19
 20uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
 21    GGML_UNUSED(ctx);
 22    GGML_UNUSED(enc);
 23
 24    ggml_backend_buffer_t buffer;
 25    buffer = apir_decode_ggml_buffer(dec);
 26
 27    ggml_tensor * tensor;
 28    // safe to remove the const qualifier here
 29    tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
 30
 31    uint32_t shmem_res_id;
 32    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
 33
 34    size_t offset;
 35    apir_decode_size_t(dec, &offset);
 36
 37    size_t size;
 38    apir_decode_size_t(dec, &size);
 39
 40    void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
 41
 42    if (!shmem_data) {
 43        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
 44        return 1;
 45    }
 46
 47    buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size);
 48
 49    return 0;
 50}
 51
 52uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
 53    GGML_UNUSED(ctx);
 54    GGML_UNUSED(enc);
 55
 56    ggml_backend_buffer_t buffer;
 57    buffer = apir_decode_ggml_buffer(dec);
 58
 59    const ggml_tensor * tensor;
 60    // safe to remove the const qualifier here
 61    tensor = apir_decode_ggml_tensor(dec);
 62
 63    uint32_t shmem_res_id;
 64    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
 65
 66    size_t offset;
 67    apir_decode_size_t(dec, &offset);
 68
 69    size_t size;
 70    apir_decode_size_t(dec, &size);
 71
 72    void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
 73    if (!shmem_data) {
 74        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
 75        return 1;
 76    }
 77
 78    buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size);
 79
 80    return 0;
 81}
 82
 83uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
 84    GGML_UNUSED(ctx);
 85
 86    ggml_backend_buffer_t buffer;
 87    buffer = apir_decode_ggml_buffer(dec);
 88
 89    const ggml_tensor * src;
 90    // safe to remove the const qualifier here
 91    src               = apir_decode_ggml_tensor(dec);
 92    ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
 93
 94    bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst);
 95
 96    apir_encode_bool_t(enc, &ret);
 97
 98    return 0;
 99}
100
101uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
102    GGML_UNUSED(ctx);
103    GGML_UNUSED(enc);
104
105    ggml_backend_buffer_t buffer;
106    buffer = apir_decode_ggml_buffer(dec);
107
108    uint8_t value;
109    apir_decode_uint8_t(dec, &value);
110
111    buffer->iface.clear(buffer, value);
112
113    return 0;
114}
115
116uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
117    GGML_UNUSED(ctx);
118    GGML_UNUSED(enc);
119
120    ggml_backend_buffer_t buffer;
121    buffer = apir_decode_ggml_buffer(dec);
122
123    if (!apir_untrack_backend_buffer(buffer)) {
124        GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer);
125        return 1;
126    }
127
128    buffer->iface.free_buffer(buffer);
129
130    return 0;
131}