1#include "backend-dispatched.h"
  2#include "backend-virgl-apir.h"
  3
  4#include "shared/api_remoting.h"
  5#include "shared/apir_backend.h"
  6#include "shared/apir_cs.h"
  7
  8#include <dlfcn.h>
  9#include <ggml-backend.h>
 10
 11#include <iostream>
 12
 13#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH"
 14#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV  "APIR_LLAMA_CPP_GGML_LIBRARY_REG"
 15#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV       "APIR_LLAMA_CPP_LOG_TO_FILE"
 16
 17#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
 18
// Handle returned by dlopen() in apir_backend_initialize(); dlclose()d and
// reset to NULL by apir_backend_deinit().
static void * backend_library_handle = NULL;
// Optional log sink, opened from the path in APIR_LLAMA_CPP_LOG_TO_FILE by
// apir_backend_initialize(); closed and reset to NULL by apir_backend_deinit().
static FILE * apir_logfile = NULL;
 21
 22static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
 23    FILE * logfile = (FILE *)user_data;
 24    fprintf(logfile, "[%d] %s", level, text);
 25    fflush(logfile);
 26}
 27
 28extern "C" {
 29void apir_backend_deinit(uint32_t virgl_ctx_id) {
 30    GGML_UNUSED(virgl_ctx_id);
 31
 32    auto buffers = apir_get_track_backend_buffers();
 33    for (const auto & buffer : buffers) {
 34        apir_untrack_backend_buffer(buffer);
 35        buffer->iface.free_buffer(buffer);
 36    }
 37
 38    if (backend_library_handle) {
 39        GGML_LOG_INFO(GGML_VIRTGPU_BCK "The GGML backend library was loaded. Unloading it.\n");
 40        dlclose(backend_library_handle);
 41        backend_library_handle = NULL;
 42    }
 43
 44    if (apir_logfile) {
 45        fclose(apir_logfile);
 46        apir_logfile = NULL;
 47    }
 48}
 49
 50#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
 51#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
 52
 53ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
 54    const char * dlsym_error;
 55
 56    const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
 57    if (apir_log_to_file) {
 58        apir_logfile = fopen(apir_log_to_file, "w");
 59        if (apir_logfile) {
 60            ggml_log_set(log_to_file_callback, apir_logfile);
 61        } else {
 62            GGML_LOG_INFO(GGML_VIRTGPU_BCK "Could not open the log file at '%s'\n", apir_log_to_file);
 63        }
 64    }
 65
 66    const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
 67    const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
 68    const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
 69
 70    if (!library_name) {
 71        GGML_LOG_ERROR(GGML_VIRTGPU_BCK
 72                       "%s: cannot open the GGML library: env var '%s' not defined\n",
 73                       __func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
 74
 75
 76        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
 77    }
 78
 79    backend_library_handle = dlopen(library_name, RTLD_LAZY);
 80
 81    if (!backend_library_handle) {
 82        GGML_LOG_ERROR(GGML_VIRTGPU_BCK
 83                       "%s: cannot open the GGML library: %s\n", __func__, dlerror());
 84
 85        return APIR_LOAD_LIBRARY_CANNOT_OPEN;
 86    }
 87
 88    if (!library_reg) {
 89        GGML_LOG_ERROR(GGML_VIRTGPU_BCK
 90                       "%s: cannot register the GGML library: env var '%s' not defined\n",
 91                       __func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
 92
 93        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
 94    }
 95
 96    void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
 97    dlsym_error                 = dlerror();
 98    if (dlsym_error) {
 99        GGML_LOG_ERROR(GGML_VIRTGPU_BCK
100                       "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
101                       __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
102
103
104        return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
105    }
106
107    uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct);
108
109    return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret);
110}
111
112uint32_t apir_backend_dispatcher(uint32_t               virgl_ctx_id,
113                                 virgl_apir_callbacks * virgl_cbs,
114                                 uint32_t               cmd_type,
115                                 char *                 dec_cur,
116                                 const char *           dec_end,
117                                 char *                 enc_cur,
118                                 const char *           enc_end,
119                                 char **                enc_cur_after) {
120    apir_encoder enc = {
121        .cur   = enc_cur,
122        .start = enc_cur,
123        .end   = enc_end,
124        .fatal = false,
125    };
126
127    apir_decoder dec = {
128        .cur   = dec_cur,
129        .end   = dec_end,
130        .fatal = false,
131    };
132
133    virgl_apir_context ctx = {
134        .ctx_id = virgl_ctx_id,
135        .iface = virgl_cbs,
136    };
137
138    if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
139        GGML_LOG_ERROR(GGML_VIRTGPU_BCK
140                       "%s: Received an invalid dispatch index (%d >= %d)\n",
141                        __func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
142        return APIR_BACKEND_FORWARD_INDEX_INVALID;
143    }
144
145    backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type];
146    uint32_t           ret         = forward_fct(&enc, &dec, &ctx);
147
148    *enc_cur_after = enc.cur;
149
150    return ret;
151}
152}