1#include "backend-dispatched.h"
 2#include "backend-virgl-apir.h"
 3#include "ggml-backend-impl.h"
 4#include "ggml-backend.h"
 5#include "ggml-impl.h"
 6#include "shared/apir_backend.h"
 7
 8#include <cstdint>
 9
10uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
11    GGML_UNUSED(ctx);
12    GGML_UNUSED(enc);
13
14    static bool async_backend_initialized = false;
15    static bool async_backend;
16
17    if (!async_backend_initialized) {
18        ggml_backend_dev_props props;
19
20        dev->iface.get_props(dev, &props);
21        async_backend             = props.caps.async;
22        async_backend_initialized = true;
23    }
24
25    uint32_t shmem_res_id;
26    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
27
28    const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
29    if (!shmem_data) {
30        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
31        apir_decoder_set_fatal(dec);
32        return 1;
33    }
34    size_t cgraph_size;
35    apir_decode_size_t(dec, &cgraph_size);
36
37    apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
38
39    ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
40
41    ggml_status status;
42#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
43    for (int idx = 0; idx < cgraph->n_nodes; idx++) {
44        ggml_tensor * op = ggml_graph_node(cgraph, idx);
45        if (dev->iface.supports_op(dev, op)) {
46            continue;
47        }
48        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
49
50        status = GGML_STATUS_ABORTED;
51        apir_encode_ggml_status(enc, &status);
52
53        return 0;
54    }
55#endif
56    status = bck->iface.graph_compute(bck, cgraph);
57
58    if (async_backend) {
59        bck->iface.synchronize(bck);
60    }
61
62    apir_encode_ggml_status(enc, &status);
63
64    return 0;
65}