#include "ggml-rpc.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
#  define WIN32_LEAN_AND_MEAN
#  ifndef NOMINMAX
#     define NOMINMAX
#  endif
#  include <windows.h>
#  include <winsock2.h>
#else
#  include <arpa/inet.h>
#  include <sys/socket.h>
#  include <sys/types.h>
#  include <netinet/in.h>
#  include <netinet/tcp.h>
#  include <netdb.h>
#  include <unistd.h>
#endif
#include <cstring>
#include <fstream>
#include <filesystem>
#include <algorithm>
  33
// Value of the GGML_RPC_DEBUG environment variable, read once during static initialization
// (nullptr when the variable is not set)
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");

// Debug logging: forwards its arguments to GGML_LOG_DEBUG only when RPC_DEBUG is non-null
#define LOG_DBG(...) \
    do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)


namespace fs = std::filesystem;

// upper bound on the number of bytes handed to a single send()/recv() call (see send_data/recv_data)
static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB

#ifdef _WIN32
// Windows build: sockets are SOCKET handles, and ssize_t is defined here for the send/recv helpers
typedef SOCKET sockfd_t;
using ssize_t = __int64;
#else
// other platforms: sockets are plain int descriptors
typedef int sockfd_t;
#endif
  50
  51// cross-platform socket
  52struct socket_t {
  53    sockfd_t fd;
  54    socket_t(sockfd_t fd) : fd(fd) {}
  55    ~socket_t() {
  56        LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
  57#ifdef _WIN32
  58        closesocket(this->fd);
  59#else
  60        close(this->fd);
  61#endif
  62    }
  63};
  64
  65// macro for nicer error messages on server crash
  66#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
  67
// all RPC structures must be packed
#pragma pack(push, 1)
// ggml_tensor is serialized into rpc_tensor
// (serialize_tensor() shows how each field is filled from a ggml_tensor)
struct rpc_tensor {
    uint64_t id;                                          // address of the client-side ggml_tensor
    uint32_t type;                                        // tensor->type
    uint64_t buffer;                                      // remote_ptr of the owning RPC buffer, 0 if none
    uint32_t ne[GGML_MAX_DIMS];                           // tensor->ne, stored as 32-bit values
    uint32_t nb[GGML_MAX_DIMS];                           // tensor->nb, stored as 32-bit values
    uint32_t op;                                          // tensor->op
    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
    int32_t  flags;
    uint64_t src[GGML_MAX_SRC];                           // addresses of the client-side source tensors
    uint64_t view_src;                                    // address of the client-side view source tensor
    uint64_t view_offs;
    uint64_t data;                                        // tensor->data pointer value
    char name[GGML_MAX_NAME];

    // explicit padding; the static_assert below requires the size to be a multiple of 8
    char padding[4];
};

static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
  90
// RPC commands
// The command value is sent as a single byte at the start of every request (see send_rpc_cmd),
// so the numeric values are part of the wire format.
enum rpc_cmd {
    RPC_CMD_ALLOC_BUFFER = 0,
    RPC_CMD_GET_ALIGNMENT,
    RPC_CMD_GET_MAX_SIZE,
    RPC_CMD_BUFFER_GET_BASE,
    RPC_CMD_FREE_BUFFER,
    RPC_CMD_BUFFER_CLEAR,
    RPC_CMD_SET_TENSOR,
    RPC_CMD_SET_TENSOR_HASH,
    RPC_CMD_GET_TENSOR,
    RPC_CMD_COPY_TENSOR,
    RPC_CMD_GRAPH_COMPUTE,
    RPC_CMD_GET_DEVICE_MEMORY,
    RPC_CMD_INIT_TENSOR,
    RPC_CMD_GET_ALLOC_SIZE,
    RPC_CMD_HELLO,
    RPC_CMD_DEVICE_COUNT,
    RPC_CMD_GRAPH_RECOMPUTE,
    RPC_CMD_COUNT, // number of enumerators above
};

// NOTE(review): presumably pinned so the version handshake keeps the same command value
// across protocol revisions - confirm before inserting new commands ahead of RPC_CMD_HELLO
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
 114
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;

// Fixed-size request (_req) and response (_rsp) payloads. They are sent and received as raw
// bytes via send_rpc_cmd, so their layout is part of the wire format.

// response to RPC_CMD_HELLO; checked against RPC_PROTO_*_VERSION in check_server_version
struct rpc_msg_hello_rsp {
    uint8_t major;
    uint8_t minor;
    uint8_t patch;
};

// NOTE(review): presumably the response to RPC_CMD_DEVICE_COUNT - its use is not visible in this part of the file
struct rpc_msg_device_count_rsp {
    uint32_t device_count;
};

// request for RPC_CMD_GET_ALLOC_SIZE: the tensor plus all of its (possibly null) sources
struct rpc_msg_get_alloc_size_req {
    uint32_t   device;
    rpc_tensor tensor;
    rpc_tensor srcs[GGML_MAX_SRC];
};

// response to RPC_CMD_GET_ALLOC_SIZE
struct rpc_msg_get_alloc_size_rsp {
    uint64_t alloc_size;
};

// request for RPC_CMD_INIT_TENSOR (sent with an empty response)
struct rpc_msg_init_tensor_req {
    rpc_tensor tensor;
};

// request for RPC_CMD_ALLOC_BUFFER
struct rpc_msg_alloc_buffer_req {
    uint32_t device;
    uint64_t size;
};

// response to RPC_CMD_ALLOC_BUFFER; the client treats remote_ptr == 0 as allocation failure
struct rpc_msg_alloc_buffer_rsp {
    uint64_t remote_ptr;
    uint64_t remote_size;
};

// request for RPC_CMD_GET_ALIGNMENT
struct rpc_msg_get_alignment_req {
    uint32_t device;
};

// response to RPC_CMD_GET_ALIGNMENT
struct rpc_msg_get_alignment_rsp {
    uint64_t alignment;
};

// request for RPC_CMD_GET_MAX_SIZE
struct rpc_msg_get_max_size_req {
    uint32_t device;
};

// response to RPC_CMD_GET_MAX_SIZE
struct rpc_msg_get_max_size_rsp {
    uint64_t max_size;
};

// request for RPC_CMD_BUFFER_GET_BASE
struct rpc_msg_buffer_get_base_req {
    uint64_t remote_ptr;
};

// response to RPC_CMD_BUFFER_GET_BASE; the client caches base_ptr in its buffer context
struct rpc_msg_buffer_get_base_rsp {
    uint64_t base_ptr;
};

// request for RPC_CMD_FREE_BUFFER (sent with an empty response)
struct rpc_msg_free_buffer_req {
    uint64_t remote_ptr;
};

// request for RPC_CMD_BUFFER_CLEAR (sent with an empty response)
struct rpc_msg_buffer_clear_req {
    uint64_t remote_ptr;
    uint8_t value;
};

// request for RPC_CMD_SET_TENSOR_HASH; hash is the fnv_hash() of the data the client wants to upload
struct rpc_msg_set_tensor_hash_req {
    rpc_tensor tensor;
    uint64_t offset;
    uint64_t hash;
};

// response to RPC_CMD_SET_TENSOR_HASH; a non-zero result lets the client skip sending the data
struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};

// request for RPC_CMD_GET_TENSOR; the response is `size` raw bytes written to the caller's buffer
struct rpc_msg_get_tensor_req {
    rpc_tensor tensor;
    uint64_t offset;
    uint64_t size;
};

// request for RPC_CMD_COPY_TENSOR
struct rpc_msg_copy_tensor_req {
    rpc_tensor src;
    rpc_tensor dst;
};

// response to RPC_CMD_COPY_TENSOR; result is returned to the caller converted to bool
struct rpc_msg_copy_tensor_rsp {
    uint8_t result;
};

// NOTE(review): presumably the request for RPC_CMD_GET_DEVICE_MEMORY - its use is not visible in this part of the file
struct rpc_msg_get_device_memory_req {
    uint32_t device;
};

// NOTE(review): presumably the response to RPC_CMD_GET_DEVICE_MEMORY - its use is not visible in this part of the file
struct rpc_msg_get_device_memory_rsp {
    uint64_t free_mem;
    uint64_t total_mem;
};

// request for RPC_CMD_GRAPH_RECOMPUTE (sent without waiting for a response)
struct rpc_msg_graph_recompute_req {
    uint32_t device;
};

#pragma pack(pop)
 224
 225// RPC data structures
 226
 227static ggml_guid_t ggml_backend_rpc_guid() {
 228    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
 229    return &guid;
 230}
 231
// Per-buffer-type state for the RPC backend
struct ggml_backend_rpc_buffer_type_context {
    std::string endpoint;  // passed to get_socket() to reach the server
    uint32_t    device;    // device index placed in requests sent to the server
    std::string name;      // returned by ggml_backend_rpc_buffer_type_name
    size_t      alignment; // returned by ggml_backend_rpc_buffer_type_get_alignment
    size_t      max_size;  // returned by ggml_backend_rpc_get_max_size
};
 239
 240struct graph_cache {
 241
 242    bool is_cached(const ggml_cgraph * cgraph) {
 243        if ((int)last_graph.size() != cgraph->n_nodes) {
 244            return false;
 245        }
 246        for (int i = 0; i < cgraph->n_nodes; i++) {
 247            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
 248                return false;
 249            }
 250        }
 251        return true;
 252    }
 253
 254    void add(const ggml_cgraph * cgraph) {
 255        last_graph.resize(cgraph->n_nodes);
 256        for (int i = 0; i < cgraph->n_nodes; i++) {
 257            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
 258        }
 259    }
 260
 261    std::vector<ggml_tensor> last_graph;
 262};
 263
// Per-backend state for the RPC backend
struct ggml_backend_rpc_context {
    std::string endpoint; // passed to get_socket() to reach the server
    uint32_t    device;   // device index placed in graph compute requests
    std::string name;     // returned by ggml_backend_rpc_name
    graph_cache gc;       // last graph sent by ggml_backend_rpc_graph_compute
};
 270
// Per-buffer state for the RPC backend.
// NOTE: the field order matters, the struct is aggregate-initialized in
// ggml_backend_rpc_buffer_type_alloc_buffer.
struct ggml_backend_rpc_buffer_context {
    std::shared_ptr<socket_t> sock; // connection used for all requests about this buffer
    void * base_ptr;                // base address reported by the server; nullptr until first queried
    uint64_t remote_ptr;            // value identifying the buffer in requests to the server
};
 276
 277// RPC helper functions
 278
 279// Computes FNV-1a hash of the data
 280static uint64_t fnv_hash(const uint8_t * data, size_t len) {
 281    const uint64_t fnv_prime = 0x100000001b3ULL;
 282    uint64_t hash = 0xcbf29ce484222325ULL;
 283
 284    for (size_t i = 0; i < len; ++i) {
 285        hash ^= data[i];
 286        hash *= fnv_prime;
 287    }
 288    return hash;
 289}
 290
 291static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
 292#ifdef _WIN32
 293    if (fd == INVALID_SOCKET) {
 294        return nullptr;
 295    }
 296#else
 297    if (fd < 0) {
 298        return nullptr;
 299    }
 300#endif
 301    return std::make_shared<socket_t>(fd);
 302}
 303
 304static bool set_no_delay(sockfd_t sockfd) {
 305    int flag = 1;
 306    // set TCP_NODELAY to disable Nagle's algorithm
 307    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
 308    return ret == 0;
 309}
 310
 311static bool set_reuse_addr(sockfd_t sockfd) {
 312    int flag = 1;
 313    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
 314    return ret == 0;
 315}
 316
 317static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
 318    struct sockaddr_in addr;
 319    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
 320    auto sock_ptr = make_socket(sockfd);
 321    if (sock_ptr == nullptr) {
 322        return nullptr;
 323    }
 324    if (!set_no_delay(sockfd)) {
 325        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
 326        return nullptr;
 327    }
 328    addr.sin_family = AF_INET;
 329    addr.sin_port = htons(port);
 330    struct hostent * server = gethostbyname(host);
 331    if (server == NULL) {
 332        GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
 333        return nullptr;
 334    }
 335    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
 336    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
 337        return nullptr;
 338    }
 339    return sock_ptr;
 340}
 341
 342static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
 343    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
 344    auto client_socket = make_socket(client_socket_fd);
 345    if (client_socket == nullptr) {
 346        return nullptr;
 347    }
 348    if (!set_no_delay(client_socket_fd)) {
 349        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
 350        return nullptr;
 351    }
 352    return client_socket;
 353}
 354
 355static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
 356    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
 357    auto sock = make_socket(sockfd);
 358    if (sock == nullptr) {
 359        return nullptr;
 360    }
 361    if (!set_reuse_addr(sockfd)) {
 362        GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
 363        return nullptr;
 364    }
 365    if (inet_addr(host) == INADDR_NONE) {
 366        GGML_LOG_ERROR("Invalid host address: %s\n", host);
 367        return nullptr;
 368    }
 369    struct sockaddr_in serv_addr;
 370    serv_addr.sin_family = AF_INET;
 371    serv_addr.sin_addr.s_addr = inet_addr(host);
 372    serv_addr.sin_port = htons(port);
 373
 374    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
 375        return nullptr;
 376    }
 377    if (listen(sockfd, 1) < 0) {
 378        return nullptr;
 379    }
 380    return sock;
 381}
 382
 383static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
 384    size_t bytes_sent = 0;
 385    while (bytes_sent < size) {
 386        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
 387        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
 388        if (n < 0) {
 389            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
 390                           bytes_sent, size_to_send);
 391            return false;
 392        }
 393        bytes_sent += (size_t)n;
 394    }
 395    return true;
 396}
 397
 398static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
 399    size_t bytes_recv = 0;
 400    while (bytes_recv < size) {
 401        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
 402        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
 403        if (n < 0) {
 404            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
 405                           bytes_recv, size_to_recv);
 406            return false;
 407        }
 408        if (n == 0) {
 409            LOG_DBG("recv returned 0 (peer closed?)\n");
 410            return false;
 411        }
 412        bytes_recv += (size_t)n;
 413    }
 414    return true;
 415}
 416
 417static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
 418    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
 419        return false;
 420    }
 421    return send_data(sockfd, msg, msg_size);
 422}
 423
 424static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
 425    uint64_t size;
 426    if (!recv_data(sockfd, &size, sizeof(size))) {
 427        return false;
 428    }
 429    if (size != msg_size) {
 430        return false;
 431    }
 432    return recv_data(sockfd, msg, msg_size);
 433}
 434
 435static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
 436    uint64_t size;
 437    if (!recv_data(sockfd, &size, sizeof(size))) {
 438        return false;
 439    }
 440    try {
 441        input.resize(size);
 442    } catch (const std::bad_alloc & e) {
 443        GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
 444        return false;
 445    }
 446    return recv_data(sockfd, input.data(), size);
 447}
 448
 449static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
 450    size_t pos = endpoint.find(':');
 451    if (pos == std::string::npos) {
 452        return false;
 453    }
 454    host = endpoint.substr(0, pos);
 455    port = std::stoi(endpoint.substr(pos + 1));
 456    return true;
 457}
 458
 459// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 460// No response
 461static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
 462    uint8_t cmd_byte = cmd;
 463    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
 464        return false;
 465    }
 466    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
 467        return false;
 468    }
 469    if (!send_data(sock->fd, input, input_size)) {
 470        return false;
 471    }
 472    return true;
 473}
 474
 475// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 476// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
 477static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
 478    if (!send_rpc_cmd(sock, cmd, input, input_size)) {
 479        return false;
 480    }
 481    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
 482    // even if we do, we can skip sending output_size from the server for commands with known output size
 483    uint64_t out_size;
 484    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
 485        return false;
 486    }
 487    if (out_size != output_size) {
 488        return false;
 489    }
 490    if (!recv_data(sock->fd, output, output_size)) {
 491        return false;
 492    }
 493    return true;
 494}
 495
 496// RPC client-side implementation
 497
 498static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
 499    rpc_msg_hello_rsp response;
 500    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
 501    RPC_STATUS_ASSERT(status);
 502    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
 503        GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
 504        return false;
 505    }
 506    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
 507        GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
 508    }
 509    return true;
 510}
 511
 512static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
 513    static std::mutex mutex;
 514    std::lock_guard<std::mutex> lock(mutex);
 515    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
 516    static bool initialized = false;
 517
 518    auto it = sockets.find(endpoint);
 519    if (it != sockets.end()) {
 520        if (auto sock = it->second.lock()) {
 521            return sock;
 522        }
 523    }
 524    std::string host;
 525    int port;
 526    if (!parse_endpoint(endpoint, host, port)) {
 527        GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
 528        return nullptr;
 529    }
 530#ifdef _WIN32
 531    if (!initialized) {
 532        WSADATA wsaData;
 533        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
 534        if (res != 0) {
 535            return nullptr;
 536        }
 537        initialized = true;
 538    }
 539#else
 540    GGML_UNUSED(initialized);
 541#endif
 542    auto sock = socket_connect(host.c_str(), port);
 543    if (sock == nullptr) {
 544        return nullptr;
 545    }
 546    if (!check_server_version(sock)) {
 547        return nullptr;
 548    }
 549    LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
 550    sockets[endpoint] = sock;
 551    return sock;
 552}
 553
 554static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 555    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 556    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
 557    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
 558    RPC_STATUS_ASSERT(status);
 559    delete ctx;
 560}
 561
 562static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 563    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 564    if (ctx->base_ptr != nullptr) {
 565        return ctx->base_ptr;
 566    }
 567    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
 568    rpc_msg_buffer_get_base_rsp response;
 569    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
 570    RPC_STATUS_ASSERT(status);
 571    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
 572    return ctx->base_ptr;
 573}
 574
 575static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
 576    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
 577}
 578
 579static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 580    rpc_tensor result;
 581    if (!tensor) {
 582        memset(&result, 0, sizeof(result));
 583        return result;
 584    }
 585
 586    result.id = reinterpret_cast<uint64_t>(tensor);
 587    result.type = tensor->type;
 588    if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
 589        ggml_backend_buffer_t buffer = tensor->buffer;
 590        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 591        result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
 592    } else {
 593        result.buffer = 0;
 594    }
 595    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
 596        result.ne[i] = tensor->ne[i];
 597        result.nb[i] = tensor->nb[i];
 598    }
 599    result.op = tensor->op;
 600    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
 601        result.op_params[i] = tensor->op_params[i];
 602    }
 603    result.flags = tensor->flags;
 604    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
 605        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
 606    }
 607    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
 608    result.view_offs = tensor->view_offs;
 609    result.data = reinterpret_cast<uint64_t>(tensor->data);
 610
 611    // Avoid sending uninitialized data over the wire
 612    memset(result.name, 0, sizeof(result.name));
 613    memset(result.padding, 0, sizeof(result.padding));
 614
 615    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
 616    return result;
 617}
 618
 619static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
 620    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 621
 622    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
 623    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
 624    // In particular, only quantized tensors need padding
 625    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
 626        rpc_msg_init_tensor_req request;
 627
 628        request.tensor = serialize_tensor(tensor);
 629
 630        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
 631        RPC_STATUS_ASSERT(status);
 632    }
 633    return GGML_STATUS_SUCCESS;
 634}
 635
 636static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 637    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 638    rpc_tensor rpc_tensor = serialize_tensor(tensor);
 639    if (size > HASH_THRESHOLD) {
 640        rpc_msg_set_tensor_hash_req request;
 641        request.tensor = rpc_tensor;
 642        request.offset = offset;
 643        request.hash = fnv_hash((const uint8_t*)data, size);
 644        rpc_msg_set_tensor_hash_rsp response;
 645        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
 646        RPC_STATUS_ASSERT(status);
 647        if (response.result) {
 648            // the server has the same data, no need to send it
 649            return;
 650        }
 651    }
 652    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
 653    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
 654    std::vector<uint8_t> input(input_size, 0);
 655    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
 656    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
 657    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
 658    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
 659    RPC_STATUS_ASSERT(status);
 660}
 661
 662static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
 663    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 664    rpc_msg_get_tensor_req request;
 665    request.tensor = serialize_tensor(tensor);
 666    request.offset = offset;
 667    request.size = size;
 668    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
 669    RPC_STATUS_ASSERT(status);
 670}
 671
 672static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
 673    if (ggml_backend_buffer_is_rpc(src->buffer)) {
 674        // check if src and dst are on the same server
 675        ggml_backend_buffer_t src_buffer = src->buffer;
 676        ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
 677        ggml_backend_buffer_t dst_buffer = dst->buffer;
 678        ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
 679        if (src_ctx->sock != dst_ctx->sock) {
 680            return false;
 681        }
 682        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 683        rpc_msg_copy_tensor_req request;
 684        request.src = serialize_tensor(src);
 685        request.dst = serialize_tensor(dst);
 686        rpc_msg_copy_tensor_rsp response;
 687        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
 688        RPC_STATUS_ASSERT(status);
 689        return response.result;
 690    }
 691    return false;
 692}
 693
 694static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
 695    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 696    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
 697    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
 698    RPC_STATUS_ASSERT(status);
 699}
 700
// Function table for RPC buffers; passed to ggml_backend_buffer_init when a buffer is allocated.
// Entries are positional, the comments name the slot each one fills. memset_tensor and reset
// are left NULL (not implemented by this backend).
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_rpc_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,
    /* .memset_tensor   = */ NULL,
    /* .set_tensor      = */ ggml_backend_rpc_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_rpc_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_rpc_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_rpc_buffer_clear,
    /* .reset           = */ NULL,
};
 712
 713static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
 714    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 715    return buft_ctx->name.c_str();
 716}
 717
 718static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 719    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 720    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
 721    rpc_msg_alloc_buffer_rsp response;
 722    auto sock = get_socket(buft_ctx->endpoint);
 723    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
 724    RPC_STATUS_ASSERT(status);
 725    if (response.remote_ptr != 0) {
 726        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
 727            ggml_backend_rpc_buffer_interface,
 728            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
 729            response.remote_size);
 730        return buffer;
 731    } else {
 732        return nullptr;
 733    }
 734}
 735
 736static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
 737    rpc_msg_get_alignment_req request = {device};
 738    rpc_msg_get_alignment_rsp response;
 739    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
 740    RPC_STATUS_ASSERT(status);
 741    return response.alignment;
 742}
 743
 744static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
 745    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 746    return buft_ctx->alignment;
 747}
 748
 749static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
 750    rpc_msg_get_max_size_req request = {device};
 751    rpc_msg_get_max_size_rsp response;
 752    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
 753    RPC_STATUS_ASSERT(status);
 754    return response.max_size;
 755}
 756
 757static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 758    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 759    return buft_ctx->max_size;
 760}
 761
 762static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
 763    // should we query the remote server for the actual size
 764    bool rpc_get = false;
 765
 766    // See comments in init_tensor.
 767    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
 768
 769    // ops that require additional memory for fleeting data on certain backends
 770    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
 771    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
 772    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
 773
 774    if (rpc_get) {
 775        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 776        auto sock = get_socket(buft_ctx->endpoint);
 777
 778        rpc_msg_get_alloc_size_req request = {
 779            /*.device =*/ buft_ctx->device,
 780            /*.tensor =*/ serialize_tensor(tensor),
 781            /*.srcs   =*/ {},
 782        };
 783
 784        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
 785        for (int i = 0; i < GGML_MAX_SRC; i++) {
 786            request.srcs[i] = serialize_tensor(tensor->src[i]);
 787        }
 788
 789        // TODO: cache the alloc responses to avoid extra RPC calls?
 790        rpc_msg_get_alloc_size_rsp response;
 791        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
 792        RPC_STATUS_ASSERT(status);
 793
 794        return response.alloc_size;
 795    }
 796
 797    return ggml_nbytes(tensor);
 798}
 799
// Function table for the RPC buffer type.
// Entries are positional, the comments name the slot each one fills. is_host is left NULL.
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
    /* .is_host          = */ NULL,
};
 808
 809static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 810    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 811
 812    return rpc_ctx->name.c_str();
 813}
 814
 815static void ggml_backend_rpc_free(ggml_backend_t backend) {
 816    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 817    delete rpc_ctx;
 818    delete backend;
 819}
 820
 821static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
 822    GGML_UNUSED(backend);
 823    // this is no-op because we don't have any async operations
 824}
 825
 826static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
 827    if (tensor == nullptr) {
 828        return;
 829    }
 830    if (visited.find(tensor) != visited.end()) {
 831        return;
 832    }
 833    visited.insert(tensor);
 834    for (int i = 0; i < GGML_MAX_SRC; i++) {
 835        add_tensor(tensor->src[i], tensors, visited);
 836    }
 837    add_tensor(tensor->view_src, tensors, visited);
 838    tensors.push_back(serialize_tensor(tensor));
 839}
 840
 841static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
 842    uint32_t n_nodes = cgraph->n_nodes;
 843    std::vector<rpc_tensor> tensors;
 844    std::unordered_set<ggml_tensor*> visited;
 845    for (uint32_t i = 0; i < n_nodes; i++) {
 846        add_tensor(cgraph->nodes[i], tensors, visited);
 847    }
 848    // serialization format:
 849    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
 850    uint32_t n_tensors = tensors.size();
 851    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
 852    output.resize(output_size, 0);
 853    uint8_t * dest = output.data();
 854    memcpy(dest, &device, sizeof(device));
 855    dest += sizeof(device);
 856    memcpy(dest, &n_nodes, sizeof(n_nodes));
 857    dest += sizeof(n_nodes);
 858    for (uint32_t i = 0; i < n_nodes; i++) {
 859        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
 860    }
 861    dest += n_nodes * sizeof(uint64_t);
 862    memcpy(dest, &n_tensors, sizeof(n_tensors));
 863    dest += sizeof(n_tensors);
 864    rpc_tensor * out_tensors = (rpc_tensor *)dest;
 865    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
 866}
 867
 868static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 869    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 870
 871    GGML_ASSERT(cgraph->n_nodes > 0);
 872    bool reuse = rpc_ctx->gc.is_cached(cgraph);
 873    if (reuse) {
 874        rpc_msg_graph_recompute_req request;
 875        request.device = rpc_ctx->device;
 876        auto sock = get_socket(rpc_ctx->endpoint);
 877        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
 878        RPC_STATUS_ASSERT(status);
 879    } else {
 880        rpc_ctx->gc.add(cgraph);
 881        std::vector<uint8_t> input;
 882        serialize_graph(rpc_ctx->device, cgraph, input);
 883        auto sock = get_socket(rpc_ctx->endpoint);
 884        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
 885        RPC_STATUS_ASSERT(status);
 886    }
 887    return GGML_STATUS_SUCCESS;
 888}
 889
 890static ggml_backend_i ggml_backend_rpc_interface = {
 891    /* .get_name                = */ ggml_backend_rpc_name,
 892    /* .free                    = */ ggml_backend_rpc_free,
 893    /* .set_tensor_async        = */ NULL,
 894    /* .get_tensor_async        = */ NULL,
 895    /* .cpy_tensor_async        = */ NULL,
 896    /* .synchronize             = */ ggml_backend_rpc_synchronize,
 897    /* .graph_plan_create       = */ NULL,
 898    /* .graph_plan_free         = */ NULL,
 899    /* .graph_plan_update       = */ NULL,
 900    /* .graph_plan_compute      = */ NULL,
 901    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
 902    /* .event_record            = */ NULL,
 903    /* .event_wait              = */ NULL,
 904    /* .graph_optimize          = */ NULL,
 905};
 906
 907ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
 908    static std::mutex mutex;
 909    std::lock_guard<std::mutex> lock(mutex);
 910    std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
 911    // NOTE: buffer types are allocated and never freed; this is by design
 912    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
 913    auto it = buft_map.find(buft_name);
 914    if (it != buft_map.end()) {
 915        return it->second;
 916    }
 917    auto sock = get_socket(endpoint);
 918    if (sock == nullptr) {
 919        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
 920        return nullptr;
 921    }
 922    size_t alignment = get_alignment(sock, device);
 923    size_t max_size = get_max_size(sock, device);
 924    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
 925        /* .endpoint  = */ endpoint,
 926        /* .device    = */ device,
 927        /* .name      = */ buft_name,
 928        /* .alignment = */ alignment,
 929        /* .max_size  = */ max_size
 930    };
 931    auto reg = ggml_backend_rpc_add_server(endpoint);
 932    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
 933        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
 934        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
 935        /* .context = */ buft_ctx
 936    };
 937    buft_map[buft_name] = buft;
 938    return buft;
 939}
 940
 941ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
 942    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
 943    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
 944        /* .endpoint = */ endpoint,
 945        /* .device   = */ device,
 946        /* .name     = */ dev_name,
 947        /* .gc       = */ {},
 948    };
 949    auto reg = ggml_backend_rpc_add_server(endpoint);
 950    ggml_backend_t backend = new ggml_backend {
 951        /* .guid    = */ ggml_backend_rpc_guid(),
 952        /* .iface   = */ ggml_backend_rpc_interface,
 953        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
 954        /* .context = */ ctx
 955    };
 956    return backend;
 957}
 958
 959bool ggml_backend_is_rpc(ggml_backend_t backend) {
 960    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 961}
 962
 963static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
 964    rpc_msg_get_device_memory_req request;
 965    request.device = device;
 966    rpc_msg_get_device_memory_rsp response;
 967    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
 968    RPC_STATUS_ASSERT(status);
 969    *free = response.free_mem;
 970    *total = response.total_mem;
 971}
 972
 973void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
 974    auto sock = get_socket(endpoint);
 975    if (sock == nullptr) {
 976        *free = 0;
 977        *total = 0;
 978        return;
 979    }
 980    get_device_memory(sock, device, free, total);
 981}
 982
 983// RPC server-side implementation
 984
 985class rpc_server {
 986public:
 987    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
 988        : backends(std::move(all_backends)), cache_dir(cache_dir) {
 989        stored_graphs.resize(backends.size());
 990    }
 991    ~rpc_server();
 992
 993    void hello(rpc_msg_hello_rsp & response);
 994    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
 995    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
 996    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
 997    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
 998    bool free_buffer(const rpc_msg_free_buffer_req & request);
 999    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
1000    bool set_tensor(const std::vector<uint8_t> & input);
1001    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
1002    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
1003    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
1004    bool graph_compute(const std::vector<uint8_t> & input);
1005    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
1006    bool init_tensor(const rpc_msg_init_tensor_req & request);
1007    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
1008    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
1009
1010    struct stored_graph {
1011        ggml_context_ptr ctx_ptr;
1012        ggml_cgraph *    graph;
1013    };
1014
1015private:
1016    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
1017    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
1018    ggml_tensor * create_node(uint64_t id,
1019                              struct ggml_context * ctx,
1020                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
1021                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
1022
1023
1024    std::vector<ggml_backend_t> backends;
1025    const char * cache_dir;
1026    std::unordered_set<ggml_backend_buffer_t> buffers;
1027    // store the last computed graph for each backend
1028    std::vector<stored_graph> stored_graphs;
1029};
1030
1031void rpc_server::hello(rpc_msg_hello_rsp & response) {
1032    response.major = RPC_PROTO_MAJOR_VERSION;
1033    response.minor = RPC_PROTO_MINOR_VERSION;
1034    response.patch = RPC_PROTO_PATCH_VERSION;
1035    LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
1036}
1037
1038bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
1039    uint32_t dev_id = request.device;
1040    if (dev_id >= backends.size()) {
1041        return false;
1042    }
1043    ggml_backend_buffer_type_t buft;
1044    struct ggml_init_params params {
1045        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
1046        /*.mem_buffer =*/ NULL,
1047        /*.no_alloc   =*/ true,
1048    };
1049
1050    ggml_context_ptr ctx_ptr { ggml_init(params) };
1051    GGML_ASSERT(ctx_ptr != nullptr);
1052    ggml_context * ctx = ctx_ptr.get();
1053
1054    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1055    if (tensor == nullptr) {
1056        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
1057        return false;
1058    }
1059    for (int i = 0; i < GGML_MAX_SRC; i++) {
1060        if (request.srcs[i].id != 0) {
1061            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
1062        }
1063    }
1064
1065    LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
1066    if (tensor->buffer == nullptr) {
1067        //No buffer allocated.
1068        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1069    } else {
1070        buft = tensor->buffer->buft;
1071    }
1072
1073    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
1074
1075    return true;
1076}
1077
1078bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
1079    uint32_t dev_id = request.device;
1080    if (dev_id >= backends.size()) {
1081        return false;
1082    }
1083    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1084    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
1085    response.remote_ptr = 0;
1086    response.remote_size = 0;
1087    if (buffer != nullptr) {
1088        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
1089        response.remote_size = buffer->size;
1090        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
1091            __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
1092        buffers.insert(buffer);
1093    } else {
1094        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
1095    }
1096    return true;
1097}
1098
1099bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
1100    uint32_t dev_id = request.device;
1101    if (dev_id >= backends.size()) {
1102        return false;
1103    }
1104    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1105    size_t alignment = ggml_backend_buft_get_alignment(buft);
1106    LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
1107    response.alignment = alignment;
1108    return true;
1109}
1110
1111bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
1112    uint32_t dev_id = request.device;
1113    if (dev_id >= backends.size()) {
1114        return false;
1115    }
1116    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1117    size_t max_size = ggml_backend_buft_get_max_size(buft);
1118    LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
1119    response.max_size = max_size;
1120    return true;
1121}
1122
1123bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
1124    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1125    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1126    if (buffers.find(buffer) == buffers.end()) {
1127        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1128        return false;
1129    }
1130    void * base = ggml_backend_buffer_get_base(buffer);
1131    response.base_ptr = reinterpret_cast<uint64_t>(base);
1132    return true;
1133}
1134
1135bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
1136    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1137    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1138    if (buffers.find(buffer) == buffers.end()) {
1139        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1140        return false;
1141    }
1142    ggml_backend_buffer_free(buffer);
1143    buffers.erase(buffer);
1144    return true;
1145}
1146
1147bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
1148    LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
1149    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1150    if (buffers.find(buffer) == buffers.end()) {
1151        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1152        return false;
1153    }
1154    ggml_backend_buffer_clear(buffer, request.value);
1155    return true;
1156}
1157
1158ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
1159    // Validate tensor type before using it
1160    if (tensor->type >= GGML_TYPE_COUNT) {
1161        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
1162        return nullptr;
1163    }
1164
1165    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
1166        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1167
1168    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
1169    if (result == nullptr) {
1170        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
1171        return nullptr;
1172    }
1173
1174    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
1175        result->nb[i] = tensor->nb[i];
1176    }
1177    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
1178    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
1179        result->buffer = nullptr;
1180    }
1181
1182    if (result->buffer) {
1183        // require that the tensor data does not go beyond the buffer end
1184        uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
1185        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
1186        uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
1187        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
1188        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
1189    }
1190
1191    result->op = (ggml_op) tensor->op;
1192    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
1193        result->op_params[i] = tensor->op_params[i];
1194    }
1195    result->flags = tensor->flags;
1196    result->data = reinterpret_cast<void *>(tensor->data);
1197    ggml_set_name(result, tensor->name);
1198    return result;
1199}
1200
1201
1202bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
1203    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
1204    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
1205        return false;
1206    }
1207    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
1208    uint64_t offset;
1209    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
1210    const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
1211
1212    struct ggml_init_params params {
1213        /*.mem_size   =*/ ggml_tensor_overhead(),
1214        /*.mem_buffer =*/ NULL,
1215        /*.no_alloc   =*/ true,
1216    };
1217    ggml_context_ptr ctx_ptr { ggml_init(params) };
1218    GGML_ASSERT(ctx_ptr != nullptr);
1219    ggml_context * ctx = ctx_ptr.get();
1220    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1221    if (tensor == nullptr || tensor->buffer == nullptr) {
1222        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1223        return false;
1224    }
1225    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
1226
1227    // sanitize tensor->data
1228    {
1229        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1230        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1231
1232        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1233            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
1234                           __func__, in_tensor->data, offset, size, p0, p1);
1235            return false;
1236        }
1237    }
1238
1239    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
1240    if (cache_dir && size > HASH_THRESHOLD) {
1241        uint64_t hash = fnv_hash((const uint8_t*)data, size);
1242        char hash_str[17];
1243        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1244        // save to cache_dir/hash_str
1245        fs::path cache_file = fs::path(cache_dir) / hash_str;
1246        std::ofstream ofs(cache_file, std::ios::binary);
1247        ofs.write((const char *)data, size);
1248        GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1249    }
1250    ggml_backend_tensor_set(tensor, data, offset, size);
1251    return true;
1252}
1253
1254bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1255    if (!cache_dir) {
1256        return false;
1257    }
1258    char hash_str[17];
1259    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1260    fs::path cache_file = fs::path(cache_dir) / hash_str;
1261    std::error_code ec;
1262    if (!fs::exists(cache_file, ec)) {
1263        return false;
1264    }
1265    std::ifstream ifs(cache_file, std::ios::binary);
1266    ifs.seekg(0, std::ios::end);
1267    size_t size = ifs.tellg();
1268    ifs.seekg(0, std::ios::beg);
1269    data.resize(size);
1270    ifs.read((char *)data.data(), size);
1271    return true;
1272}
1273
1274bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
1275{
1276    std::vector<uint8_t> cached_file;
1277    if (!get_cached_file(request.hash, cached_file)) {
1278        response.result = 0;
1279        return true;
1280    }
1281    size_t size = cached_file.size();
1282    struct ggml_init_params params {
1283        /*.mem_size   =*/ ggml_tensor_overhead(),
1284        /*.mem_buffer =*/ NULL,
1285        /*.no_alloc   =*/ true,
1286    };
1287    ggml_context_ptr ctx_ptr { ggml_init(params) };
1288    GGML_ASSERT(ctx_ptr != nullptr);
1289    ggml_context * ctx = ctx_ptr.get();
1290    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1291    if (tensor == nullptr || tensor->buffer == nullptr) {
1292        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1293        return false;
1294    }
1295    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1296            __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1297
1298    // sanitize tensor->data
1299    {
1300        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1301        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1302
1303        if (request.tensor.data + request.offset < p0
1304         || request.tensor.data + request.offset >= p1
1305         || size > (p1 - request.tensor.data - request.offset)) {
1306            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1307                           __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
1308            return false;
1309        }
1310    }
1311    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
1312    response.result = 1;
1313    return true;
1314}
1315
1316bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
1317    struct ggml_init_params params {
1318        /*.mem_size   =*/ ggml_tensor_overhead(),
1319        /*.mem_buffer =*/ NULL,
1320        /*.no_alloc   =*/ true,
1321    };
1322    ggml_context_ptr ctx_ptr { ggml_init(params) };
1323    GGML_ASSERT(ctx_ptr != nullptr);
1324    ggml_context * ctx = ctx_ptr.get();
1325    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1326    if (tensor == nullptr) {
1327        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
1328        return false;
1329    }
1330    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
1331    // Call the backend's buffer_init_tensor function
1332    ggml_backend_buffer_t buffer = tensor->buffer;
1333    if (buffer && buffer->iface.init_tensor) {
1334        buffer->iface.init_tensor(buffer, tensor);
1335    } else {
1336        GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
1337    }
1338
1339    if (tensor->extra != nullptr) {
1340        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
1341        // Currently unimplemented.
1342        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
1343        return false;
1344    }
1345
1346    return true;
1347}
1348
1349bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
1350    struct ggml_init_params params {
1351        /*.mem_size   =*/ ggml_tensor_overhead(),
1352        /*.mem_buffer =*/ NULL,
1353        /*.no_alloc   =*/ true,
1354    };
1355    ggml_context_ptr ctx_ptr { ggml_init(params) };
1356    GGML_ASSERT(ctx_ptr != nullptr);
1357    ggml_context * ctx = ctx_ptr.get();
1358    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1359    if (tensor == nullptr || tensor->buffer == nullptr) {
1360        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1361        return false;
1362    }
1363    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
1364
1365    // sanitize tensor->data
1366    {
1367        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1368        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1369
1370        if (request.tensor.data + request.offset < p0 ||
1371            request.tensor.data + request.offset >= p1 ||
1372            request.size > (p1 - request.tensor.data - request.offset)) {
1373                GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1374                               __func__, request.tensor.data, request.offset, request.size, p0, p1);
1375                return false;
1376        }
1377    }
1378
1379    response.resize(request.size, 0);
1380    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
1381    return true;
1382}
1383
1384bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
1385    struct ggml_init_params params {
1386        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
1387        /*.mem_buffer =*/ NULL,
1388        /*.no_alloc   =*/ true,
1389    };
1390    ggml_context_ptr ctx_ptr { ggml_init(params) };
1391    GGML_ASSERT(ctx_ptr != nullptr);
1392    ggml_context * ctx = ctx_ptr.get();
1393
1394    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
1395    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
1396    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
1397        GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
1398        return false;
1399    }
1400
1401    uint64_t src_size   = (uint64_t) ggml_nbytes(src);
1402    uint64_t dst_data   = (uint64_t) dst->data;
1403    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
1404    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
1405
1406    if (dst_data + src_size > dst_base + dst_buf_sz) {
1407        GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1408                         "    write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
1409                         "    buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
1410                         __func__,
1411                         dst_data,
1412                         dst_data + src_size,
1413                         dst_base,
1414                         dst_base + dst_buf_sz);
1415        return false;
1416    }
1417
1418    LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
1419            __func__, (void*) src->buffer, (void*) dst->buffer);
1420
1421    response.result = ggml_backend_buffer_copy_tensor(src, dst);
1422    return true;
1423}
1424
1425ggml_tensor * rpc_server::create_node(uint64_t id,
1426                                      struct ggml_context * ctx,
1427                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
1428                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
1429    if (tensor_map.find(id) != tensor_map.end()) {
1430        return tensor_map[id];
1431    }
1432    // Safely find the tensor pointer
1433    auto it_ptr = tensor_ptrs.find(id);
1434    if (it_ptr == tensor_ptrs.end()) {
1435        return nullptr;
1436    }
1437    const rpc_tensor * tensor = it_ptr->second;
1438
1439    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
1440    if (result == nullptr) {
1441        return nullptr;
1442    }
1443    tensor_map[id] = result;
1444    for (int i = 0; i < GGML_MAX_SRC; i++) {
1445        // Check if the source ID is 0 before calling create_node recursively
1446        if (tensor->src[i] == 0) {
1447            result->src[i] = nullptr;
1448        } else {
1449            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1450            // If the recursive call failed for a non-zero ID, propagate the error
1451            if (result->src[i] == nullptr) {
1452                GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1453                               __func__, i, tensor->src[i], id);
1454                // Must return nullptr to signal failure up the call stack
1455                return nullptr;
1456            }
1457        }
1458    }
1459
1460    // Handle view_src similarly
1461    if (tensor->view_src == 0) {
1462        result->view_src = nullptr;
1463    } else {
1464        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
1465        // If the recursive call failed for a non-zero ID, propagate the error
1466        if (result->view_src == nullptr) {
1467            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1468                           __func__, tensor->view_src, id);
1469            // Must return nullptr to signal failure up the call stack
1470            return nullptr;
1471        }
1472    }
1473    result->view_offs = tensor->view_offs;
1474    return result;
1475}
1476
1477bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
1478    // serialization format:
1479    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1480    if (input.size() < 2*sizeof(uint32_t)) {
1481        return false;
1482    }
1483    const uint8_t * src = input.data();
1484    uint32_t device;
1485    memcpy(&device, src, sizeof(device));
1486    src += sizeof(device);
1487    if (device >= backends.size()) {
1488        return false;
1489    }
1490    uint32_t n_nodes;
1491    memcpy(&n_nodes, src, sizeof(n_nodes));
1492    src += sizeof(n_nodes);
1493    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1494        return false;
1495    }
1496    const uint64_t * nodes = (const uint64_t *)src;
1497    src += n_nodes*sizeof(uint64_t);
1498    uint32_t n_tensors;
1499    memcpy(&n_tensors, src, sizeof(n_tensors));
1500    src += sizeof(n_tensors);
1501    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1502        return false;
1503    }
1504    const rpc_tensor * tensors = (const rpc_tensor *)src;
1505    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
1506
1507    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1508
1509    struct ggml_init_params params = {
1510        /*.mem_size   =*/ buf_size,
1511        /*.mem_buffer =*/ NULL,
1512        /*.no_alloc   =*/ true,
1513    };
1514    ggml_context_ptr ctx_ptr { ggml_init(params) };
1515    GGML_ASSERT(ctx_ptr != nullptr);
1516    ggml_context * ctx = ctx_ptr.get();
1517    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1518    graph->n_nodes = n_nodes;
1519    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
1520    tensor_ptrs.reserve(n_tensors);
1521    for (uint32_t i = 0; i < n_tensors; i++) {
1522        tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
1523    }
1524    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
1525    tensor_map.reserve(n_nodes);
1526    for (uint32_t i = 0; i < n_nodes; i++) {
1527        int64_t id;
1528        memcpy(&id, &nodes[i], sizeof(id));
1529        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1530
1531        // Check if create_node failed for a *non-zero* ID.
1532        // If id was 0, create_node returning nullptr is expected.
1533        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
1534        if (graph->nodes[i] == nullptr && id != 0) {
1535            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
1536            return false;
1537        }
1538    }
1539    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1540    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1541    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1542    stored_graphs[device].graph = graph;
1543    return true;
1544}
1545
1546bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
1547    uint32_t device = request.device;
1548    if (device >= backends.size()) {
1549        return false;
1550    }
1551    if (stored_graphs[device].graph == nullptr) {
1552        return false;
1553    }
1554    ggml_cgraph * graph = stored_graphs[device].graph;
1555    LOG_DBG("[%s] device: %u\n", __func__, device);
1556    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1557    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1558    return true;
1559}
1560
1561bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1562    uint32_t dev_id = request.device;
1563    if (dev_id >= backends.size()) {
1564        return false;
1565    }
1566    size_t free, total;
1567    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1568    ggml_backend_dev_memory(dev, &free, &total);
1569    response.free_mem = free;
1570    response.total_mem = total;
1571    LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1572    return true;
1573}
1574
1575rpc_server::~rpc_server() {
1576    for (auto buffer : buffers) {
1577        ggml_backend_buffer_free(buffer);
1578    }
1579}
1580
1581static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1582                             sockfd_t sockfd) {
1583    rpc_server server(backends, cache_dir);
1584    uint8_t cmd;
1585    if (!recv_data(sockfd, &cmd, 1)) {
1586        return;
1587    }
1588    // the first command sent by the client must be HELLO
1589    if (cmd != RPC_CMD_HELLO) {
1590        GGML_LOG_ERROR("Expected HELLO command, update client\n");
1591        return;
1592    }
1593    if (!recv_msg(sockfd, nullptr, 0)) {
1594        return;
1595    }
1596    rpc_msg_hello_rsp response;
1597    server.hello(response);
1598    if (!send_msg(sockfd, &response, sizeof(response))) {
1599        return;
1600    }
1601    while (true) {
1602        if (!recv_data(sockfd, &cmd, 1)) {
1603            break;
1604        }
1605        if (cmd >= RPC_CMD_COUNT) {
1606            // fail fast if the command is invalid
1607            GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1608            break;
1609        }
1610        switch (cmd) {
1611            case RPC_CMD_HELLO: {
1612                // HELLO command is handled above
1613                return;
1614            }
1615            case RPC_CMD_DEVICE_COUNT: {
1616                if (!recv_msg(sockfd, nullptr, 0)) {
1617                    return;
1618                }
1619                rpc_msg_device_count_rsp response;
1620                response.device_count = backends.size();
1621                if (!send_msg(sockfd, &response, sizeof(response))) {
1622                    return;
1623                }
1624                break;
1625            }
1626            case RPC_CMD_ALLOC_BUFFER: {
1627                rpc_msg_alloc_buffer_req request;
1628                if (!recv_msg(sockfd, &request, sizeof(request))) {
1629                    return;
1630                }
1631                rpc_msg_alloc_buffer_rsp response;
1632                if (!server.alloc_buffer(request, response)) {
1633                    return;
1634                }
1635                if (!send_msg(sockfd, &response, sizeof(response))) {
1636                    return;
1637                }
1638                break;
1639            }
1640            case RPC_CMD_GET_ALLOC_SIZE: {
1641                rpc_msg_get_alloc_size_req request;
1642                if (!recv_msg(sockfd, &request, sizeof(request))) {
1643                    return;
1644                }
1645                rpc_msg_get_alloc_size_rsp response;
1646                if (!server.get_alloc_size(request, response)) {
1647                    return;
1648                }
1649                if (!send_msg(sockfd, &response, sizeof(response))) {
1650                    return;
1651                }
1652                break;
1653            }
1654            case RPC_CMD_GET_ALIGNMENT: {
1655                rpc_msg_get_alignment_req request;
1656                if (!recv_msg(sockfd, &request, sizeof(request))) {
1657                    return;
1658                }
1659                rpc_msg_get_alignment_rsp response;
1660                if (!server.get_alignment(request, response)) {
1661                    return;
1662                }
1663                if (!send_msg(sockfd, &response, sizeof(response))) {
1664                    return;
1665                }
1666                break;
1667            }
1668            case RPC_CMD_GET_MAX_SIZE: {
1669                rpc_msg_get_max_size_req request;
1670                if (!recv_msg(sockfd, &request, sizeof(request))) {
1671                    return;
1672                }
1673                rpc_msg_get_max_size_rsp response;
1674                if (!server.get_max_size(request, response)) {
1675                    return;
1676                }
1677                if (!send_msg(sockfd, &response, sizeof(response))) {
1678                    return;
1679                }
1680                break;
1681            }
1682            case RPC_CMD_BUFFER_GET_BASE: {
1683                rpc_msg_buffer_get_base_req request;
1684                if (!recv_msg(sockfd, &request, sizeof(request))) {
1685                    return;
1686                }
1687                rpc_msg_buffer_get_base_rsp response;
1688                if (!server.buffer_get_base(request, response)) {
1689                    return;
1690                }
1691                if (!send_msg(sockfd, &response, sizeof(response))) {
1692                    return;
1693                }
1694                break;
1695            }
1696            case RPC_CMD_FREE_BUFFER: {
1697                rpc_msg_free_buffer_req request;
1698                if (!recv_msg(sockfd, &request, sizeof(request))) {
1699                    return;
1700                }
1701                if (!server.free_buffer(request)) {
1702                    return;
1703                }
1704                if (!send_msg(sockfd, nullptr, 0)) {
1705                    return;
1706                }
1707                break;
1708            }
1709            case RPC_CMD_BUFFER_CLEAR: {
1710                rpc_msg_buffer_clear_req request;
1711                if (!recv_msg(sockfd, &request, sizeof(request))) {
1712                    return;
1713                }
1714                if (!server.buffer_clear(request)) {
1715                    return;
1716                }
1717                if (!send_msg(sockfd, nullptr, 0)) {
1718                    return;
1719                }
1720                break;
1721            }
1722            case RPC_CMD_SET_TENSOR: {
1723                std::vector<uint8_t> input;
1724                if (!recv_msg(sockfd, input)) {
1725                    return;
1726                }
1727                if (!server.set_tensor(input)) {
1728                    return;
1729                }
1730                break;
1731            }
1732            case RPC_CMD_SET_TENSOR_HASH: {
1733                rpc_msg_set_tensor_hash_req request;
1734                if (!recv_msg(sockfd, &request, sizeof(request))) {
1735                    return;
1736                }
1737                rpc_msg_set_tensor_hash_rsp response;
1738                if (!server.set_tensor_hash(request, response)) {
1739                    return;
1740                }
1741                if (!send_msg(sockfd, &response, sizeof(response))) {
1742                    return;
1743                }
1744                break;
1745            }
1746            case RPC_CMD_INIT_TENSOR: {
1747                rpc_msg_init_tensor_req request;
1748                if (!recv_msg(sockfd, &request,sizeof(request))) {
1749                    return;
1750                }
1751                if (!server.init_tensor(request)) {
1752                    return;
1753                }
1754                if (!send_msg(sockfd, nullptr, 0)) {
1755                    return;
1756                }
1757                break;
1758            }
1759            case RPC_CMD_GET_TENSOR: {
1760                rpc_msg_get_tensor_req request;
1761                if (!recv_msg(sockfd, &request, sizeof(request))) {
1762                    return;
1763                }
1764                std::vector<uint8_t> response;
1765                if (!server.get_tensor(request, response)) {
1766                    return;
1767                }
1768                if (!send_msg(sockfd, response.data(), response.size())) {
1769                    return;
1770                }
1771                break;
1772            }
1773            case RPC_CMD_COPY_TENSOR: {
1774                rpc_msg_copy_tensor_req request;
1775                if (!recv_msg(sockfd, &request, sizeof(request))) {
1776                    return;
1777                }
1778                rpc_msg_copy_tensor_rsp response;
1779                if (!server.copy_tensor(request, response)) {
1780                    return;
1781                }
1782                if (!send_msg(sockfd, &response, sizeof(response))) {
1783                    return;
1784                }
1785                break;
1786            }
1787            case RPC_CMD_GRAPH_COMPUTE: {
1788                std::vector<uint8_t> input;
1789                if (!recv_msg(sockfd, input)) {
1790                    return;
1791                }
1792                if (!server.graph_compute(input)) {
1793                    return;
1794                }
1795                break;
1796            }
1797            case RPC_CMD_GRAPH_RECOMPUTE: {
1798                rpc_msg_graph_recompute_req request;
1799                if (!recv_msg(sockfd, &request, sizeof(request))) {
1800                    return;
1801                }
1802                if (!server.graph_recompute(request)) {
1803                    return;
1804                }
1805                break;
1806            }
1807            case RPC_CMD_GET_DEVICE_MEMORY: {
1808                rpc_msg_get_device_memory_req request;
1809                if (!recv_msg(sockfd, &request, sizeof(request))) {
1810                    return;
1811                }
1812                rpc_msg_get_device_memory_rsp response;
1813                if (!server.get_device_memory(request, response)) {
1814                    return;
1815                }
1816                if (!send_msg(sockfd, &response, sizeof(response))) {
1817                    return;
1818                }
1819                break;
1820            }
1821            default: {
1822                GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1823                return;
1824            }
1825        }
1826    }
1827}
1828
1829void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1830                                   size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1831    if (n_devices == 0 || devices == nullptr) {
1832        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
1833        return;
1834    }
1835    std::vector<ggml_backend_t> backends;
1836    printf("Starting RPC server v%d.%d.%d\n",
1837        RPC_PROTO_MAJOR_VERSION,
1838        RPC_PROTO_MINOR_VERSION,
1839        RPC_PROTO_PATCH_VERSION);
1840    printf("  endpoint       : %s\n", endpoint);
1841    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
1842    printf("Devices:\n");
1843    for (size_t i = 0; i < n_devices; i++) {
1844        auto dev = devices[i];
1845        size_t free, total;
1846        ggml_backend_dev_memory(dev, &free, &total);
1847        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1848               total / 1024 / 1024, free / 1024 / 1024);
1849        auto backend = ggml_backend_dev_init(dev, nullptr);
1850        if (!backend) {
1851            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
1852            return;
1853        }
1854        backends.push_back(backend);
1855        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
1856        if (reg) {
1857            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
1858            if (ggml_backend_set_n_threads_fn) {
1859                ggml_backend_set_n_threads_fn(backend, n_threads);
1860            }
1861        }
1862    }
1863
1864    std::string host;
1865    int port;
1866    if (!parse_endpoint(endpoint, host, port)) {
1867        return;
1868    }
1869#ifdef _WIN32
1870    {
1871        WSADATA wsaData;
1872        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
1873        if (res != 0) {
1874            fprintf(stderr, "WSAStartup failed: %d\n", res);
1875            return;
1876        }
1877    }
1878#endif
1879    auto server_socket = create_server_socket(host.c_str(), port);
1880    if (server_socket == nullptr) {
1881        fprintf(stderr, "Failed to create server socket\n");
1882        return;
1883    }
1884    while (true) {
1885        auto client_socket = socket_accept(server_socket->fd);
1886        if (client_socket == nullptr) {
1887            fprintf(stderr, "Failed to accept client connection\n");
1888            return;
1889        }
1890        printf("Accepted client connection\n");
1891        fflush(stdout);
1892        rpc_serve_client(backends, cache_dir, client_socket->fd);
1893        printf("Client connection closed\n");
1894        fflush(stdout);
1895    }
1896#ifdef _WIN32
1897    WSACleanup();
1898#endif
1899    for (auto backend : backends) {
1900        ggml_backend_free(backend);
1901    }
1902}
1903
1904// device interface
1905
// Per-device state attached to an RPC ggml_backend_device via its `context` pointer.
// Field order matters: instances are created with positional aggregate initialisation.
struct ggml_backend_rpc_device_context {
    // endpoint string of the server that hosts this device
    std::string endpoint;
    // device index passed along with requests to that server
    uint32_t device;
    // value returned by the device's get_name callback
    std::string name;
    // value returned by the device's get_description callback
    std::string description;
};
1912
1913static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1914    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1915
1916    return ctx->name.c_str();
1917}
1918
1919static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1920    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1921
1922    return ctx->description.c_str();
1923}
1924
1925static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1926    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1927
1928    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
1929}
1930
1931static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1932    // TODO: obtain value from the server
1933    return GGML_BACKEND_DEVICE_TYPE_GPU;
1934
1935    GGML_UNUSED(dev);
1936}
1937
1938static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1939    props->name        = ggml_backend_rpc_device_get_name(dev);
1940    props->description = ggml_backend_rpc_device_get_description(dev);
1941    props->type        = ggml_backend_rpc_device_get_type(dev);
1942    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1943    props->caps = {
1944        /* .async                 = */ false,
1945        /* .host_buffer           = */ false,
1946        /* .buffer_from_host_ptr  = */ false,
1947        /* .events                = */ false,
1948    };
1949}
1950
1951static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1952    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1953
1954    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
1955
1956    GGML_UNUSED(params);
1957}
1958
1959static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1960    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1961
1962    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
1963
1964    GGML_UNUSED(dev);
1965}
1966
1967static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1968    GGML_UNUSED(dev);
1969    GGML_UNUSED(op);
1970    //TODO: call the remote backend and cache the results
1971    return true;
1972}
1973
1974static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1975    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1976        return false;
1977    }
1978    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1979    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1980    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
1981}
1982
1983static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1984    /* .get_name             = */ ggml_backend_rpc_device_get_name,
1985    /* .get_description      = */ ggml_backend_rpc_device_get_description,
1986    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
1987    /* .get_type             = */ ggml_backend_rpc_device_get_type,
1988    /* .get_props            = */ ggml_backend_rpc_device_get_props,
1989    /* .init_backend         = */ ggml_backend_rpc_device_init,
1990    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
1991    /* .get_host_buffer_type = */ NULL,
1992    /* .buffer_from_host_ptr = */ NULL,
1993    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
1994    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
1995    /* .offload_op           = */ NULL,
1996    /* .event_new            = */ NULL,
1997    /* .event_free           = */ NULL,
1998    /* .event_synchronize    = */ NULL,
1999};
2000
2001// backend reg interface
2002
2003struct ggml_backend_rpc_reg_context {
2004    std::string                     name;
2005    std::vector<ggml_backend_dev_t> devices;
2006};
2007
2008static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
2009    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2010    return ctx ? ctx->name.c_str() : "RPC";
2011}
2012
2013static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
2014    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2015    return ctx ? ctx->devices.size() : 0;
2016}
2017
2018static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2019    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2020    if (ctx == nullptr) {
2021        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
2022    } else {
2023        GGML_ASSERT(index < ctx->devices.size());
2024        return ctx->devices[index];
2025    }
2026}
2027
2028static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2029    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
2030        return (void *)ggml_backend_rpc_add_server;
2031    }
2032    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
2033        return (void *)ggml_backend_rpc_start_server;
2034    }
2035    return NULL;
2036
2037    GGML_UNUSED(reg);
2038}
2039
2040static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
2041    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
2042    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
2043    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
2044    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
2045};
2046
2047ggml_backend_reg_t ggml_backend_rpc_reg(void) {
2048    static struct ggml_backend_reg ggml_backend_rpc_reg = {
2049        /* .api_version = */ GGML_BACKEND_API_VERSION,
2050        /* .iface       = */ ggml_backend_rpc_reg_i,
2051        /* .context     = */ NULL,
2052    };
2053
2054    return &ggml_backend_rpc_reg;
2055}
2056
2057static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
2058    auto sock = get_socket(endpoint);
2059    if (sock == nullptr) {
2060        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
2061        return 0;
2062    }
2063    rpc_msg_device_count_rsp response;
2064    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
2065    RPC_STATUS_ASSERT(status);
2066    return response.device_count;
2067}
2068
2069static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
2070    /* .get_name          = */ ggml_backend_rpc_reg_get_name,
2071    /* .get_device_count  = */ ggml_backend_rpc_reg_get_device_count,
2072    /* .get_device        = */ ggml_backend_rpc_reg_get_device,
2073    /* .get_proc_address  = */ ggml_backend_rpc_get_proc_address,
2074};
2075
2076ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
2077    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
2078    static std::mutex mutex;
2079    static uint32_t dev_id = 0;
2080    std::lock_guard<std::mutex> lock(mutex);
2081    if (reg_map.find(endpoint) != reg_map.end()) {
2082        return reg_map[endpoint];
2083    }
2084    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
2085    if (dev_count == 0) {
2086        return nullptr;
2087    }
2088    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
2089    ctx->name = "RPC[" + std::string(endpoint) + "]";
2090    for (uint32_t ind = 0; ind < dev_count; ind++) {
2091        std::string dev_name = "RPC" + std::to_string(dev_id);
2092        std::string dev_desc = std::string(endpoint);
2093        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
2094            /* .endpoint    = */ endpoint,
2095            /* .device      = */ ind,
2096            /* .name        = */ dev_name,
2097            /* .description = */ dev_desc
2098        };
2099
2100        ggml_backend_dev_t dev = new ggml_backend_device {
2101            /* .iface   = */ ggml_backend_rpc_device_i,
2102            /* .reg     = */ ggml_backend_rpc_reg(),
2103            /* .context = */ dev_ctx,
2104        };
2105        ctx->devices.push_back(dev);
2106        dev_id++;
2107    }
2108    ggml_backend_reg_t reg = new ggml_backend_reg {
2109        /* .api_version = */ GGML_BACKEND_API_VERSION,
2110        /* .iface       = */ ggml_backend_rpc_reg_interface,
2111        /* .context     = */ ctx
2112    };
2113    reg_map[endpoint] = reg;
2114    return reg;
2115}
2116
2117
// Macro invocation naming ggml_backend_rpc_reg as this backend's registry function.
// NOTE(review): presumably emits the entry points used when the backend is built
// as a dynamically loaded library - confirm against the macro definition in ggml-backend-impl.h.
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)