#include "ggml-rpc.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
#  define WIN32_LEAN_AND_MEAN
#  ifndef NOMINMAX
#    define NOMINMAX
#  endif
#  include <windows.h>
#  include <winsock2.h>
#else
#  include <arpa/inet.h>
#  include <sys/socket.h>
#  include <sys/types.h>
#  include <netinet/in.h>
#  include <netinet/tcp.h>
#  include <netdb.h>
#  include <unistd.h>
#endif
#include <cstring>
#include <cstdlib>
#include <fstream>
#include <filesystem>
#include <algorithm>

static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");

#define LOG_DBG(...) \
    do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)

namespace fs = std::filesystem;

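// send_data()/recv_data() below loop in chunks of at most this size; some
// socket implementations cap the number of bytes a single send()/recv() call
// will handle, so large tensors are moved as a sequence of 1 GiB pieces.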
static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB

#ifdef _WIN32
typedef SOCKET sockfd_t;
using ssize_t = __int64;
#else
typedef int sockfd_t;
#endif

// cross-platform socket
struct socket_t {
    sockfd_t fd;
    socket_t(sockfd_t fd) : fd(fd) {}
    ~socket_t() {
        LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
        closesocket(this->fd);
#else
        close(this->fd);
#endif
    }
};

// macro for nicer error messages on server crash
#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")

// all RPC structures must be packed
#pragma pack(push, 1)
// ggml_tensor is serialized into rpc_tensor
struct rpc_tensor {
    uint64_t id;
    uint32_t type;
    uint64_t buffer;
    uint32_t ne[GGML_MAX_DIMS];
    uint32_t nb[GGML_MAX_DIMS];
    uint32_t op;
    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
    int32_t  flags;
    uint64_t src[GGML_MAX_SRC];
    uint64_t view_src;
    uint64_t view_offs;
    uint64_t data;
    char     name[GGML_MAX_NAME];

    char padding[4];
};

static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
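
// NOTE: the pointer-sized fields (id, buffer, src, view_src, data) carry raw
// addresses from the peer's address space; they are never dereferenced as-is,
// only used as opaque handles that the receiving side validates before use.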

// RPC commands
enum rpc_cmd {
    RPC_CMD_ALLOC_BUFFER = 0,
    RPC_CMD_GET_ALIGNMENT,
    RPC_CMD_GET_MAX_SIZE,
    RPC_CMD_BUFFER_GET_BASE,
    RPC_CMD_FREE_BUFFER,
    RPC_CMD_BUFFER_CLEAR,
    RPC_CMD_SET_TENSOR,
    RPC_CMD_SET_TENSOR_HASH,
    RPC_CMD_GET_TENSOR,
    RPC_CMD_COPY_TENSOR,
    RPC_CMD_GRAPH_COMPUTE,
    RPC_CMD_GET_DEVICE_MEMORY,
    RPC_CMD_INIT_TENSOR,
    RPC_CMD_GET_ALLOC_SIZE,
    RPC_CMD_HELLO,
    RPC_CMD_DEVICE_COUNT,
    RPC_CMD_GRAPH_RECOMPUTE,
    RPC_CMD_COUNT,
};
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must always be 14");

// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
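
// Hash-first flow for large tensors: the client sends only the FNV-1a hash of
// the data (RPC_CMD_SET_TENSOR_HASH); if the server finds a matching file in
// its local cache it loads the data from there and replies result=1, otherwise
// the client falls back to a full RPC_CMD_SET_TENSOR transfer.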

struct rpc_msg_hello_rsp {
    uint8_t major;
    uint8_t minor;
    uint8_t patch;
};

struct rpc_msg_device_count_rsp {
    uint32_t device_count;
};

struct rpc_msg_get_alloc_size_req {
    uint32_t device;
    rpc_tensor tensor;
    rpc_tensor srcs[GGML_MAX_SRC];
};

struct rpc_msg_get_alloc_size_rsp {
    uint64_t alloc_size;
};

struct rpc_msg_init_tensor_req {
    rpc_tensor tensor;
};

struct rpc_msg_alloc_buffer_req {
    uint32_t device;
    uint64_t size;
};

struct rpc_msg_alloc_buffer_rsp {
    uint64_t remote_ptr;
    uint64_t remote_size;
};

struct rpc_msg_get_alignment_req {
    uint32_t device;
};

struct rpc_msg_get_alignment_rsp {
    uint64_t alignment;
};

struct rpc_msg_get_max_size_req {
    uint32_t device;
};

struct rpc_msg_get_max_size_rsp {
    uint64_t max_size;
};

struct rpc_msg_buffer_get_base_req {
    uint64_t remote_ptr;
};

struct rpc_msg_buffer_get_base_rsp {
    uint64_t base_ptr;
};

struct rpc_msg_free_buffer_req {
    uint64_t remote_ptr;
};

struct rpc_msg_buffer_clear_req {
    uint64_t remote_ptr;
    uint8_t  value;
};

struct rpc_msg_set_tensor_hash_req {
    rpc_tensor tensor;
    uint64_t offset;
    uint64_t hash;
};

struct rpc_msg_set_tensor_hash_rsp {
    uint8_t result;
};

struct rpc_msg_get_tensor_req {
    rpc_tensor tensor;
    uint64_t offset;
    uint64_t size;
};

struct rpc_msg_copy_tensor_req {
    rpc_tensor src;
    rpc_tensor dst;
};

struct rpc_msg_copy_tensor_rsp {
    uint8_t result;
};

struct rpc_msg_get_device_memory_req {
    uint32_t device;
};

struct rpc_msg_get_device_memory_rsp {
    uint64_t free_mem;
    uint64_t total_mem;
};

struct rpc_msg_graph_recompute_req {
    uint32_t device;
};

#pragma pack(pop)

// RPC data structures

static ggml_guid_t ggml_backend_rpc_guid() {
    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
    return &guid;
}

struct ggml_backend_rpc_buffer_type_context {
    std::string endpoint;
    uint32_t    device;
    std::string name;
    size_t      alignment;
    size_t      max_size;
};

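// Snapshot of the most recently submitted graph, compared byte-wise against
// the next one; when they match, the client sends RPC_CMD_GRAPH_RECOMPUTE
// instead of re-serializing and re-sending the entire graph.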
struct graph_cache {
    bool is_cached(const ggml_cgraph * cgraph) {
        if ((int)last_graph.size() != cgraph->n_nodes) {
            return false;
        }
        for (int i = 0; i < cgraph->n_nodes; i++) {
            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
                return false;
            }
        }
        return true;
    }

    void add(const ggml_cgraph * cgraph) {
        last_graph.resize(cgraph->n_nodes);
        for (int i = 0; i < cgraph->n_nodes; i++) {
            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
        }
    }

    std::vector<ggml_tensor> last_graph;
};

struct ggml_backend_rpc_context {
    std::string endpoint;
    uint32_t    device;
    std::string name;
    graph_cache gc;
};

struct ggml_backend_rpc_buffer_context {
    std::shared_ptr<socket_t> sock;
    void * base_ptr;
    uint64_t remote_ptr;
};

// RPC helper functions

// Computes FNV-1a hash of the data
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;
    uint64_t hash = 0xcbf29ce484222325ULL;

    for (size_t i = 0; i < len; ++i) {
        hash ^= data[i];
        hash *= fnv_prime;
    }
    return hash;
}

static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    if (fd == INVALID_SOCKET) {
        return nullptr;
    }
#else
    if (fd < 0) {
        return nullptr;
    }
#endif
    return std::make_shared<socket_t>(fd);
}

static bool set_no_delay(sockfd_t sockfd) {
    int flag = 1;
    // set TCP_NODELAY to disable Nagle's algorithm
    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
    return ret == 0;
}

static bool set_reuse_addr(sockfd_t sockfd) {
    int flag = 1;
    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
    return ret == 0;
}

static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
    struct sockaddr_in addr;
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock_ptr = make_socket(sockfd);
    if (sock_ptr == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(sockfd)) {
        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    struct hostent * server = gethostbyname(host);
    if (server == NULL) {
        GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
        return nullptr;
    }
    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        return nullptr;
    }
    return sock_ptr;
}

static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
    auto client_socket = make_socket(client_socket_fd);
    if (client_socket == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(client_socket_fd)) {
        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    return client_socket;
}

static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock = make_socket(sockfd);
    if (sock == nullptr) {
        return nullptr;
    }
    if (!set_reuse_addr(sockfd)) {
        GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
        return nullptr;
    }
    if (inet_addr(host) == INADDR_NONE) {
        GGML_LOG_ERROR("Invalid host address: %s\n", host);
        return nullptr;
    }
    struct sockaddr_in serv_addr;
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = inet_addr(host);
    serv_addr.sin_port = htons(port);

    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
        return nullptr;
    }
    if (listen(sockfd, 1) < 0) {
        return nullptr;
    }
    return sock;
}

static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
    size_t bytes_sent = 0;
    while (bytes_sent < size) {
        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
        if (n < 0) {
            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
                           bytes_sent, size_to_send);
            return false;
        }
        bytes_sent += (size_t)n;
    }
    return true;
}

static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
    size_t bytes_recv = 0;
    while (bytes_recv < size) {
        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
        if (n < 0) {
            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
                           bytes_recv, size_to_recv);
            return false;
        }
        if (n == 0) {
            LOG_DBG("recv returned 0 (peer closed?)\n");
            return false;
        }
        bytes_recv += (size_t)n;
    }
    return true;
}

static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
        return false;
    }
    return send_data(sockfd, msg, msg_size);
}

static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
    uint64_t size;
    if (!recv_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    if (size != msg_size) {
        return false;
    }
    return recv_data(sockfd, msg, msg_size);
}

static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
    uint64_t size;
    if (!recv_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    try {
        input.resize(size);
    } catch (const std::bad_alloc & e) {
        GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
        return false;
    }
    return recv_data(sockfd, input.data(), size);
}

static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    host = endpoint.substr(0, pos);
    port = std::stoi(endpoint.substr(pos + 1));
    return true;
}
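
// NOTE: std::stoi throws on a non-numeric port (e.g. "host:abc"), so endpoints
// should be validated before reaching parse_endpoint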

// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// No response
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
    uint8_t cmd_byte = cmd;
    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
        return false;
    }
    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
        return false;
    }
    if (!send_data(sock->fd, input, input_size)) {
        return false;
    }
    return true;
}

// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
    if (!send_rpc_cmd(sock, cmd, input, input_size)) {
        return false;
    }
    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
    // even if we do, we can skip sending output_size from the server for commands with known output size
    uint64_t out_size;
    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
        return false;
    }
    if (out_size != output_size) {
        return false;
    }
    if (!recv_data(sock->fd, output, output_size)) {
        return false;
    }
    return true;
}
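
// Example round trip for RPC_CMD_GET_ALIGNMENT (cmd = 1, 4-byte request,
// 8-byte response); sizes are 8-byte values in host byte order, the protocol
// performs no endianness conversion:
//   client -> server : | 0x01 | 4 | device (4 bytes) |
//   server -> client : | 8 | alignment (8 bytes) |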

// RPC client-side implementation

static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
    rpc_msg_hello_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
        GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
        return false;
    }
    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
        GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
    }
    return true;
}

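// Connections are cached per endpoint and held via weak_ptr: all buffers and
// backends that talk to the same server share a single socket, which is closed
// automatically when the last shared_ptr owner releases it.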
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
    static bool initialized = false;

    auto it = sockets.find(endpoint);
    if (it != sockets.end()) {
        if (auto sock = it->second.lock()) {
            return sock;
        }
    }
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
        return nullptr;
    }
#ifdef _WIN32
    if (!initialized) {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            return nullptr;
        }
        initialized = true;
    }
#else
    GGML_UNUSED(initialized);
#endif
    auto sock = socket_connect(host.c_str(), port);
    if (sock == nullptr) {
        return nullptr;
    }
    if (!check_server_version(sock)) {
        return nullptr;
    }
    LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
    sockets[endpoint] = sock;
    return sock;
}

static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
    RPC_STATUS_ASSERT(status);
    delete ctx;
}

static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    if (ctx->base_ptr != nullptr) {
        return ctx->base_ptr;
    }
    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
    rpc_msg_buffer_get_base_rsp response;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
    return ctx->base_ptr;
}

static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
}

static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
    rpc_tensor result;
    if (!tensor) {
        memset(&result, 0, sizeof(result));
        return result;
    }

    result.id = reinterpret_cast<uint64_t>(tensor);
    result.type = tensor->type;
    if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
        ggml_backend_buffer_t buffer = tensor->buffer;
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
        result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
    } else {
        result.buffer = 0;
    }
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result.ne[i] = tensor->ne[i];
        result.nb[i] = tensor->nb[i];
    }
    result.op = tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result.op_params[i] = tensor->op_params[i];
    }
    result.flags = tensor->flags;
    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
    }
    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
    result.view_offs = tensor->view_offs;
    result.data = reinterpret_cast<uint64_t>(tensor->data);

    // Avoid sending uninitialized data over the wire
    memset(result.name, 0, sizeof(result.name));
    memset(result.padding, 0, sizeof(result.padding));

    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
    return result;
}

static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;

    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
    // In particular, only quantized tensors need padding
    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
        rpc_msg_init_tensor_req request;

        request.tensor = serialize_tensor(tensor);

        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
        RPC_STATUS_ASSERT(status);
    }
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_tensor rpc_tensor = serialize_tensor(tensor);
    if (size > HASH_THRESHOLD) {
        rpc_msg_set_tensor_hash_req request;
        request.tensor = rpc_tensor;
        request.offset = offset;
        request.hash = fnv_hash((const uint8_t*)data, size);
        rpc_msg_set_tensor_hash_rsp response;
        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
        RPC_STATUS_ASSERT(status);
        if (response.result) {
            // the server has the same data, no need to send it
            return;
        }
    }
    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
    std::vector<uint8_t> input(input_size, 0);
    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
    RPC_STATUS_ASSERT(status);
}

static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_get_tensor_req request;
    request.tensor = serialize_tensor(tensor);
    request.offset = offset;
    request.size = size;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
    RPC_STATUS_ASSERT(status);
}

static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    if (ggml_backend_buffer_is_rpc(src->buffer)) {
        // check if src and dst are on the same server
        ggml_backend_buffer_t src_buffer = src->buffer;
        ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
        ggml_backend_buffer_t dst_buffer = dst->buffer;
        ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
        if (src_ctx->sock != dst_ctx->sock) {
            return false;
        }
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
        rpc_msg_copy_tensor_req request;
        request.src = serialize_tensor(src);
        request.dst = serialize_tensor(dst);
        rpc_msg_copy_tensor_rsp response;
        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
        RPC_STATUS_ASSERT(status);
        return response.result;
    }
    return false;
}

static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
    RPC_STATUS_ASSERT(status);
}

static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
    /* .free_buffer   = */ ggml_backend_rpc_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_rpc_buffer_get_base,
    /* .init_tensor   = */ ggml_backend_rpc_buffer_init_tensor,
    /* .memset_tensor = */ NULL,
    /* .set_tensor    = */ ggml_backend_rpc_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_rpc_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_rpc_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_rpc_buffer_clear,
    /* .reset         = */ NULL,
};

static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
    rpc_msg_alloc_buffer_rsp response;
    auto sock = get_socket(buft_ctx->endpoint);
    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    if (response.remote_ptr != 0) {
        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
            ggml_backend_rpc_buffer_interface,
            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
            response.remote_size);
        return buffer;
    } else {
        return nullptr;
    }
}

static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
    rpc_msg_get_alignment_req request = {device};
    rpc_msg_get_alignment_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    return response.alignment;
}

static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->alignment;
}

static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
    rpc_msg_get_max_size_req request = {device};
    rpc_msg_get_max_size_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    return response.max_size;
}

static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->max_size;
}

static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    // should we query the remote server for the actual size?
    bool rpc_get = false;

    // See comments in init_tensor.
    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);

    // ops that require additional memory for fleeting data on certain backends
    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;

    if (rpc_get) {
        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
        auto sock = get_socket(buft_ctx->endpoint);

        rpc_msg_get_alloc_size_req request = {
            /*.device =*/ buft_ctx->device,
            /*.tensor =*/ serialize_tensor(tensor),
            /*.srcs   =*/ {},
        };

        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            request.srcs[i] = serialize_tensor(tensor->src[i]);
        }

        // TODO: cache the alloc responses to avoid extra RPC calls?
        rpc_msg_get_alloc_size_rsp response;
        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
        RPC_STATUS_ASSERT(status);

        return response.alloc_size;
    }

    return ggml_nbytes(tensor);
}

static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_rpc_buffer_type_name,
    /* .alloc_buffer   = */ ggml_backend_rpc_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_rpc_buffer_type_get_alignment,
    /* .get_max_size   = */ ggml_backend_rpc_get_max_size,
    /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
    /* .is_host        = */ NULL,
};

static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;

    return rpc_ctx->name.c_str();
}

static void ggml_backend_rpc_free(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    delete rpc_ctx;
    delete backend;
}

static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
    GGML_UNUSED(backend);
    // this is a no-op because we don't have any async operations
}

static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
    if (tensor == nullptr) {
        return;
    }
    if (visited.find(tensor) != visited.end()) {
        return;
    }
    visited.insert(tensor);
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        add_tensor(tensor->src[i], tensors, visited);
    }
    add_tensor(tensor->view_src, tensors, visited);
    tensors.push_back(serialize_tensor(tensor));
}
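
// serialize_graph flattens the graph: add_tensor (above) collects every node,
// source and view_src exactly once, and the node list itself is sent as the
// client's pointer values, which the server resolves back into tensors by id.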

static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
    uint32_t n_nodes = cgraph->n_nodes;
    std::vector<rpc_tensor> tensors;
    std::unordered_set<ggml_tensor*> visited;
    for (uint32_t i = 0; i < n_nodes; i++) {
        add_tensor(cgraph->nodes[i], tensors, visited);
    }
    // serialization format:
    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    uint32_t n_tensors = tensors.size();
    size_t output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
    output.resize(output_size, 0);
    uint8_t * dest = output.data();
    memcpy(dest, &device, sizeof(device));
    dest += sizeof(device);
    memcpy(dest, &n_nodes, sizeof(n_nodes));
    dest += sizeof(n_nodes);
    for (uint32_t i = 0; i < n_nodes; i++) {
        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
    }
    dest += n_nodes * sizeof(uint64_t);
    memcpy(dest, &n_tensors, sizeof(n_tensors));
    dest += sizeof(n_tensors);
    rpc_tensor * out_tensors = (rpc_tensor *)dest;
    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}

static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;

    GGML_ASSERT(cgraph->n_nodes > 0);
    bool reuse = rpc_ctx->gc.is_cached(cgraph);
    if (reuse) {
        rpc_msg_graph_recompute_req request;
        request.device = rpc_ctx->device;
        auto sock = get_socket(rpc_ctx->endpoint);
        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
        RPC_STATUS_ASSERT(status);
    } else {
        rpc_ctx->gc.add(cgraph);
        std::vector<uint8_t> input;
        serialize_graph(rpc_ctx->device, cgraph, input);
        auto sock = get_socket(rpc_ctx->endpoint);
        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
        RPC_STATUS_ASSERT(status);
    }
    return GGML_STATUS_SUCCESS;
}

static ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name           = */ ggml_backend_rpc_name,
    /* .free               = */ ggml_backend_rpc_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ ggml_backend_rpc_synchronize,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_rpc_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
    /* .graph_optimize     = */ NULL,
};

ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
    // NOTE: buffer types are allocated and never freed; this is by design
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
    auto it = buft_map.find(buft_name);
    if (it != buft_map.end()) {
        return it->second;
    }
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
        return nullptr;
    }
    size_t alignment = get_alignment(sock, device);
    size_t max_size = get_max_size(sock, device);
    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .endpoint  = */ endpoint,
        /* .device    = */ device,
        /* .name      = */ buft_name,
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };
    auto reg = ggml_backend_rpc_add_server(endpoint);
    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
        /* .context = */ buft_ctx
    };
    buft_map[buft_name] = buft;
    return buft;
}

ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .device   = */ device,
        /* .name     = */ dev_name,
        /* .gc       = */ {},
    };
    auto reg = ggml_backend_rpc_add_server(endpoint);
    ggml_backend_t backend = new ggml_backend {
        /* .guid    = */ ggml_backend_rpc_guid(),
        /* .iface   = */ ggml_backend_rpc_interface,
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
        /* .context = */ ctx
    };
    return backend;
}

bool ggml_backend_is_rpc(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}

static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
    rpc_msg_get_device_memory_req request;
    request.device = device;
    rpc_msg_get_device_memory_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    *free = response.free_mem;
    *total = response.total_mem;
}

void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        *free = 0;
        *total = 0;
        return;
    }
    get_device_memory(sock, device, free, total);
}

// RPC server-side implementation

class rpc_server {
public:
    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
        : backends(std::move(all_backends)), cache_dir(cache_dir) {
        stored_graphs.resize(backends.size());
    }
    ~rpc_server();

    void hello(rpc_msg_hello_rsp & response);
    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
    bool free_buffer(const rpc_msg_free_buffer_req & request);
    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
    bool set_tensor(const std::vector<uint8_t> & input);
    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
    bool graph_compute(const std::vector<uint8_t> & input);
    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
    bool init_tensor(const rpc_msg_init_tensor_req & request);
    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

    struct stored_graph {
        ggml_context_ptr ctx_ptr;
        ggml_cgraph * graph;
    };

private:
    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
    ggml_tensor * create_node(uint64_t id,
                              struct ggml_context * ctx,
                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);

    std::vector<ggml_backend_t> backends;
    const char * cache_dir;
    std::unordered_set<ggml_backend_buffer_t> buffers;
    // store the last computed graph for each backend
    std::vector<stored_graph> stored_graphs;
};
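
// One rpc_server instance serves a single client connection (see
// rpc_serve_client below); its destructor releases any buffers the client
// allocated but never freed.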

void rpc_server::hello(rpc_msg_hello_rsp & response) {
    response.major = RPC_PROTO_MAJOR_VERSION;
    response.minor = RPC_PROTO_MINOR_VERSION;
    response.patch = RPC_PROTO_PATCH_VERSION;
    LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
}

bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft;
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();

    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr) {
        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
        return false;
    }
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (request.srcs[i].id != 0) {
            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
        }
    }

    LOG_DBG("[%s] device: %u, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
    if (tensor->buffer == nullptr) {
        // No buffer allocated.
        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    } else {
        buft = tensor->buffer->buft;
    }

    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);

    return true;
}

bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
    response.remote_ptr = 0;
    response.remote_size = 0;
    if (buffer != nullptr) {
        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
        response.remote_size = buffer->size;
        LOG_DBG("[%s] device: %u, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
                __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
        buffers.insert(buffer);
    } else {
        LOG_DBG("[%s] device: %u, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
    }
    return true;
}

bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    LOG_DBG("[%s] device: %u, alignment: %zu\n", __func__, dev_id, alignment);
    response.alignment = alignment;
    return true;
}

bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    LOG_DBG("[%s] device: %u, max_size: %zu\n", __func__, dev_id, max_size);
    response.max_size = max_size;
    return true;
}

bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
        return false;
    }
    void * base = ggml_backend_buffer_get_base(buffer);
    response.base_ptr = reinterpret_cast<uint64_t>(base);
    return true;
}

bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
        return false;
    }
    ggml_backend_buffer_free(buffer);
    buffers.erase(buffer);
    return true;
}

bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
    LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
        return false;
    }
    ggml_backend_buffer_clear(buffer, request.value);
    return true;
}

ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
    // Validate tensor type before using it
    if (tensor->type >= GGML_TYPE_COUNT) {
        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
        return nullptr;
    }

    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
                                              tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);

    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
    if (result == nullptr) {
        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
        return nullptr;
    }

    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result->nb[i] = tensor->nb[i];
    }
    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
        result->buffer = nullptr;
    }

    if (result->buffer) {
        // require that the tensor data does not go beyond the buffer end
        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
    }

    result->op = (ggml_op) tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result->op_params[i] = tensor->op_params[i];
    }
    result->flags = tensor->flags;
    result->data = reinterpret_cast<void *>(tensor->data);
    ggml_set_name(result, tensor->name);
    return result;
}
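
// deserialize_tensor treats all wire data as untrusted: the type is
// range-checked, the buffer handle must be one this server handed out, and the
// tensor data range is bounds-checked against that buffer before use.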

bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
        return false;
    }
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);

    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    if (tensor == nullptr || tensor->buffer == nullptr) {
        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
        return false;
    }
    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);

    // sanitize tensor->data
    {
        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
                           __func__, in_tensor->data, offset, size, p0, p1);
            return false;
        }
    }

    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
    if (cache_dir && size > HASH_THRESHOLD) {
        uint64_t hash = fnv_hash((const uint8_t*)data, size);
        char hash_str[17];
        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
        // save to cache_dir/hash_str
        fs::path cache_file = fs::path(cache_dir) / hash_str;
        std::ofstream ofs(cache_file, std::ios::binary);
        ofs.write((const char *)data, size);
        GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
    }
    ggml_backend_tensor_set(tensor, data, offset, size);
    return true;
}

bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
    if (!cache_dir) {
        return false;
    }
    char hash_str[17];
    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
    fs::path cache_file = fs::path(cache_dir) / hash_str;
    std::error_code ec;
    if (!fs::exists(cache_file, ec)) {
        return false;
    }
    std::ifstream ifs(cache_file, std::ios::binary);
    ifs.seekg(0, std::ios::end);
    size_t size = ifs.tellg();
    ifs.seekg(0, std::ios::beg);
    data.resize(size);
    ifs.read((char *)data.data(), size);
    return true;
}

bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
{
    std::vector<uint8_t> cached_file;
    if (!get_cached_file(request.hash, cached_file)) {
        response.result = 0;
        return true;
    }
    size_t size = cached_file.size();
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr || tensor->buffer == nullptr) {
        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
        return false;
    }
    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
            __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);

    // sanitize tensor->data
    {
        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

        if (request.tensor.data + request.offset < p0
            || request.tensor.data + request.offset >= p1
            || size > (p1 - request.tensor.data - request.offset)) {
            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
                           __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
            return false;
        }
    }
    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
    response.result = 1;
    return true;
}

bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr) {
        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
        return false;
    }
    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
    // Call the backend's buffer_init_tensor function
    ggml_backend_buffer_t buffer = tensor->buffer;
    if (buffer && buffer->iface.init_tensor) {
        buffer->iface.init_tensor(buffer, tensor);
    } else {
        GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
    }

    if (tensor->extra != nullptr) {
        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
        // Currently unimplemented.
        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
        return false;
    }

    return true;
}

bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr || tensor->buffer == nullptr) {
        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
        return false;
    }
    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);

    // sanitize tensor->data
    {
        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

        if (request.tensor.data + request.offset < p0 ||
            request.tensor.data + request.offset >= p1 ||
            request.size > (p1 - request.tensor.data - request.offset)) {
            GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
                           __func__, request.tensor.data, request.offset, request.size, p0, p1);
            return false;
        }
    }

    response.resize(request.size, 0);
    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
    return true;
}

bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
    struct ggml_init_params params {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();

    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
        GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
        return false;
    }

    uint64_t src_size   = (uint64_t) ggml_nbytes(src);
    uint64_t dst_data   = (uint64_t) dst->data;
    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);

    if (dst_data + src_size > dst_base + dst_buf_sz) {
        GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
                       "    write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
                       "    buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
                       __func__,
                       dst_data,
                       dst_data + src_size,
                       dst_base,
                       dst_base + dst_buf_sz);
        return false;
    }

    LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
            __func__, (void*) src->buffer, (void*) dst->buffer);

    response.result = ggml_backend_buffer_copy_tensor(src, dst);
    return true;
}

ggml_tensor * rpc_server::create_node(uint64_t id,
                                      struct ggml_context * ctx,
                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
    if (tensor_map.find(id) != tensor_map.end()) {
        return tensor_map[id];
    }
    // Safely find the tensor pointer
    auto it_ptr = tensor_ptrs.find(id);
    if (it_ptr == tensor_ptrs.end()) {
        return nullptr;
    }
    const rpc_tensor * tensor = it_ptr->second;

    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
    if (result == nullptr) {
        return nullptr;
    }
    tensor_map[id] = result;
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        // Check if the source ID is 0 before calling create_node recursively
        if (tensor->src[i] == 0) {
            result->src[i] = nullptr;
        } else {
            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
            // If the recursive call failed for a non-zero ID, propagate the error
            if (result->src[i] == nullptr) {
                GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
                               __func__, i, tensor->src[i], id);
                // Must return nullptr to signal failure up the call stack
                return nullptr;
            }
        }
    }

    // Handle view_src similarly
    if (tensor->view_src == 0) {
        result->view_src = nullptr;
    } else {
        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
        // If the recursive call failed for a non-zero ID, propagate the error
        if (result->view_src == nullptr) {
            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
                           __func__, tensor->view_src, id);
            // Must return nullptr to signal failure up the call stack
            return nullptr;
        }
    }
    result->view_offs = tensor->view_offs;
    return result;
}
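
// NOTE: tensor ids are pointer values from the client's address space and
// id 0 encodes "no tensor": a nullptr result for id 0 is success, and
// graph_compute below only treats nullptr as an error for non-zero ids.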

bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
    // serialization format:
    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    if (input.size() < 2*sizeof(uint32_t)) {
        return false;
    }
    const uint8_t * src = input.data();
    uint32_t device;
    memcpy(&device, src, sizeof(device));
    src += sizeof(device);
    if (device >= backends.size()) {
        return false;
    }
    uint32_t n_nodes;
    memcpy(&n_nodes, src, sizeof(n_nodes));
    src += sizeof(n_nodes);
    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
        return false;
    }
    const uint64_t * nodes = (const uint64_t *)src;
    src += n_nodes*sizeof(uint64_t);
    uint32_t n_tensors;
    memcpy(&n_tensors, src, sizeof(n_tensors));
    src += sizeof(n_tensors);
    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
        return false;
    }
    const rpc_tensor * tensors = (const rpc_tensor *)src;
    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);

    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
    graph->n_nodes = n_nodes;
    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
    tensor_ptrs.reserve(n_tensors);
    for (uint32_t i = 0; i < n_tensors; i++) {
        tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
    }
    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
    tensor_map.reserve(n_nodes);
    for (uint32_t i = 0; i < n_nodes; i++) {
        int64_t id;
        memcpy(&id, &nodes[i], sizeof(id));
        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);

        // Check if create_node failed for a *non-zero* ID.
        // If id was 0, create_node returning nullptr is expected.
        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
        if (graph->nodes[i] == nullptr && id != 0) {
            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
            return false;
        }
    }
    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
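    // keep the deserialized context and graph alive so that a subsequent
    // RPC_CMD_GRAPH_RECOMPUTE can re-run the same graph without resending it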
1541 stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1542 stored_graphs[device].graph = graph;
1543 return true;
1544}
1545
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
    uint32_t device = request.device;
    if (device >= backends.size()) {
        return false;
    }
    if (stored_graphs[device].graph == nullptr) {
        return false;
    }
    ggml_cgraph * graph = stored_graphs[device].graph;
    LOG_DBG("[%s] device: %u\n", __func__, device);
    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
    return true;
}

bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    size_t free, total;
    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
    ggml_backend_dev_memory(dev, &free, &total);
    response.free_mem  = free;
    response.total_mem = total;
    LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
    return true;
}

rpc_server::~rpc_server() {
    for (auto buffer : buffers) {
        ggml_backend_buffer_free(buffer);
    }
}

static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
                             sockfd_t sockfd) {
    rpc_server server(backends, cache_dir);
    uint8_t cmd;
    if (!recv_data(sockfd, &cmd, 1)) {
        return;
    }
    // the first command sent by the client must be HELLO
    if (cmd != RPC_CMD_HELLO) {
        GGML_LOG_ERROR("Expected HELLO command, update client\n");
        return;
    }
    if (!recv_msg(sockfd, nullptr, 0)) {
        return;
    }
    rpc_msg_hello_rsp response;
    server.hello(response);
    if (!send_msg(sockfd, &response, sizeof(response))) {
        return;
    }
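    // main request loop: read one command byte, then dispatch on it;
    // request/response payloads in both directions are framed by
    // recv_msg()/send_msg()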
    while (true) {
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        if (cmd >= RPC_CMD_COUNT) {
            // fail fast if the command is invalid
            GGML_LOG_ERROR("Unknown command: %d\n", cmd);
            break;
        }
        switch (cmd) {
            case RPC_CMD_HELLO: {
                // HELLO is only valid as the very first command (handled above);
                // receiving it again is a protocol error, so drop the connection
                return;
            }
            case RPC_CMD_DEVICE_COUNT: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_device_count_rsp response;
                response.device_count = backends.size();
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_ALLOC_BUFFER: {
                rpc_msg_alloc_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_alloc_buffer_rsp response;
                if (!server.alloc_buffer(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALLOC_SIZE: {
                rpc_msg_get_alloc_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alloc_size_rsp response;
                if (!server.get_alloc_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALIGNMENT: {
                rpc_msg_get_alignment_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alignment_rsp response;
                if (!server.get_alignment(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_MAX_SIZE: {
                rpc_msg_get_max_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_max_size_rsp response;
                if (!server.get_max_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_GET_BASE: {
                rpc_msg_buffer_get_base_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_buffer_get_base_rsp response;
                if (!server.buffer_get_base(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_FREE_BUFFER: {
                rpc_msg_free_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.free_buffer(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_CLEAR: {
                rpc_msg_buffer_clear_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.buffer_clear(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.set_tensor(input)) {
                    return;
                }
                break;
            }
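            // SET_TENSOR_HASH is the deduplicating variant of SET_TENSOR: the
            // client sends a content hash instead of the full payload, and the
            // response reports whether the server could satisfy the write from
            // its local cache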
            case RPC_CMD_SET_TENSOR_HASH: {
                rpc_msg_set_tensor_hash_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_set_tensor_hash_rsp response;
                if (!server.set_tensor_hash(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_INIT_TENSOR: {
                rpc_msg_init_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.init_tensor(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_TENSOR: {
                rpc_msg_get_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                std::vector<uint8_t> response;
                if (!server.get_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, response.data(), response.size())) {
                    return;
                }
                break;
            }
            case RPC_CMD_COPY_TENSOR: {
                rpc_msg_copy_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_copy_tensor_rsp response;
                if (!server.copy_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_COMPUTE: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.graph_compute(input)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_RECOMPUTE: {
                rpc_msg_graph_recompute_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.graph_recompute(request)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_DEVICE_MEMORY: {
                rpc_msg_get_device_memory_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_device_memory_rsp response;
                if (!server.get_device_memory(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            default: {
                GGML_LOG_ERROR("Unknown command: %d\n", cmd);
                return;
            }
        }
    }
}

void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
                                   size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
    if (n_devices == 0 || devices == nullptr) {
        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
        return;
    }
    std::vector<ggml_backend_t> backends;
    printf("Starting RPC server v%d.%d.%d\n",
           RPC_PROTO_MAJOR_VERSION,
           RPC_PROTO_MINOR_VERSION,
           RPC_PROTO_PATCH_VERSION);
    printf("  endpoint    : %s\n", endpoint);
    printf("  local cache : %s\n", cache_dir ? cache_dir : "n/a");
    printf("Devices:\n");
    for (size_t i = 0; i < n_devices; i++) {
        auto dev = devices[i];
        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);
        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
               total / 1024 / 1024, free / 1024 / 1024);
        auto backend = ggml_backend_dev_init(dev, nullptr);
        if (!backend) {
            fprintf(stderr, "Failed to create backend for device %s\n", ggml_backend_dev_name(dev));
            return;
        }
        backends.push_back(backend);
        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
        if (reg) {
            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
            if (ggml_backend_set_n_threads_fn) {
                ggml_backend_set_n_threads_fn(backend, n_threads);
            }
        }
    }

    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return;
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            fprintf(stderr, "WSAStartup failed: %d\n", res);
            return;
        }
    }
#endif
    auto server_socket = create_server_socket(host.c_str(), port);
    if (server_socket == nullptr) {
        fprintf(stderr, "Failed to create server socket\n");
        return;
    }
    while (true) {
        auto client_socket = socket_accept(server_socket->fd);
        if (client_socket == nullptr) {
            fprintf(stderr, "Failed to accept client connection\n");
            return;
        }
        printf("Accepted client connection\n");
        fflush(stdout);
        rpc_serve_client(backends, cache_dir, client_socket->fd);
        printf("Client connection closed\n");
        fflush(stdout);
    }
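    // not reached: the accept loop above only exits via early returns, so the
    // cleanup below never runs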
#ifdef _WIN32
    WSACleanup();
#endif
    for (auto backend : backends) {
        ggml_backend_free(backend);
    }
}
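
// A minimal usage sketch (illustrative, not part of this file; the endpoint
// and thread count are placeholders): expose every locally available device
// over RPC, using the device enumeration API from ggml-backend.h:
//
//   std::vector<ggml_backend_dev_t> devs;
//   for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
//       devs.push_back(ggml_backend_dev_get(i));
//   }
//   // blocks forever, serving one client at a time
//   ggml_backend_rpc_start_server("0.0.0.0:50052", nullptr, 4, devs.size(), devs.data());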

// device interface

struct ggml_backend_rpc_device_context {
    std::string endpoint;
    uint32_t    device;
    std::string name;
    std::string description;
};

static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ctx->name.c_str();
}

static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ctx->description.c_str();
}

static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
}

static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    // TODO: obtain value from the server
    return GGML_BACKEND_DEVICE_TYPE_GPU;

    GGML_UNUSED(dev);
}

static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_rpc_device_get_name(dev);
    props->description = ggml_backend_rpc_device_get_description(dev);
    props->type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ false,
    };
}
1950
1951static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1952 ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1953
1954 return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
1955
1956 GGML_UNUSED(params);
1957}
1958
1959static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1960 ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1961
1962 return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
1963
1964 GGML_UNUSED(dev);
1965}
1966
1967static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1968 GGML_UNUSED(dev);
1969 GGML_UNUSED(op);
1970 //TODO: call the remote backend and cache the results
1971 return true;
1972}
1973
1974static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1975 if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1976 return false;
1977 }
1978 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1979 ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1980 return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
1981}
1982
1983static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1984 /* .get_name = */ ggml_backend_rpc_device_get_name,
1985 /* .get_description = */ ggml_backend_rpc_device_get_description,
1986 /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1987 /* .get_type = */ ggml_backend_rpc_device_get_type,
1988 /* .get_props = */ ggml_backend_rpc_device_get_props,
1989 /* .init_backend = */ ggml_backend_rpc_device_init,
1990 /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1991 /* .get_host_buffer_type = */ NULL,
1992 /* .buffer_from_host_ptr = */ NULL,
1993 /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1994 /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1995 /* .offload_op = */ NULL,
1996 /* .event_new = */ NULL,
1997 /* .event_free = */ NULL,
1998 /* .event_synchronize = */ NULL,
1999};

// backend reg interface

struct ggml_backend_rpc_reg_context {
    std::string name;
    std::vector<ggml_backend_dev_t> devices;
};

static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
    return ctx ? ctx->name.c_str() : "RPC";
}

static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
    return ctx ? ctx->devices.size() : 0;
}

static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
    if (ctx == nullptr) {
        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
    } else {
        GGML_ASSERT(index < ctx->devices.size());
        return ctx->devices[index];
    }
}

static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
        return (void *)ggml_backend_rpc_add_server;
    }
    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
        return (void *)ggml_backend_rpc_start_server;
    }
    return NULL;

    GGML_UNUSED(reg);
}

static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};

ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg ggml_backend_rpc_reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ NULL,
    };

    return &ggml_backend_rpc_reg;
}
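
// note: the static registry returned by ggml_backend_rpc_reg() has a NULL
// context and therefore exposes no devices of its own - it mainly serves
// proc-address lookups; per-endpoint registries with enumerated devices are
// created by ggml_backend_rpc_add_server() below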

static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
        return 0;
    }
    rpc_msg_device_count_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    return response.device_count;
}

ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
    static std::mutex mutex;
    static uint32_t dev_id = 0;
    std::lock_guard<std::mutex> lock(mutex);
    if (reg_map.find(endpoint) != reg_map.end()) {
        return reg_map[endpoint];
    }
    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
    if (dev_count == 0) {
        return nullptr;
    }
    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
    ctx->name = "RPC[" + std::string(endpoint) + "]";
    for (uint32_t ind = 0; ind < dev_count; ind++) {
        std::string dev_name = "RPC" + std::to_string(dev_id);
        std::string dev_desc = std::string(endpoint);
        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
            /* .endpoint    = */ endpoint,
            /* .device      = */ ind,
            /* .name        = */ dev_name,
            /* .description = */ dev_desc
        };

        ggml_backend_dev_t dev = new ggml_backend_device {
            /* .iface   = */ ggml_backend_rpc_device_i,
            /* .reg     = */ ggml_backend_rpc_reg(),
            /* .context = */ dev_ctx,
        };
        ctx->devices.push_back(dev);
        dev_id++;
    }
    ggml_backend_reg_t reg = new ggml_backend_reg {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ ctx
    };
    reg_map[endpoint] = reg;
    return reg;
}
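
// A minimal client-side usage sketch (illustrative; the endpoint is a
// placeholder): register a remote server and initialize a backend for its
// first device:
//
//   ggml_backend_reg_t reg = ggml_backend_rpc_add_server("192.168.1.10:50052");
//   if (reg != nullptr && ggml_backend_reg_dev_count(reg) > 0) {
//       ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
//       ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
//       // ... build and compute graphs with `backend` as usual ...
//       ggml_backend_free(backend);
//   }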

GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)