#include "ggml-metal.h"

#include "ggml-impl.h"
#include "ggml-backend-impl.h"

#include "ggml-metal-device.h"
#include "ggml-metal-context.h"
#include "ggml-metal-ops.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#define GGML_METAL_NAME "MTL"
#define GGML_METAL_MAX_DEVICES 16

// number of Metal devices
// note: can be overridden with the GGML_METAL_DEVICES env variable to simulate virtual devices
static int g_devices = 1;
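
// usage sketch (hypothetical invocation, assuming the env variable is read at
// first registry access):
//
//   GGML_METAL_DEVICES=2 ./my-app
//
// the registry at the bottom of this file then reports two logical devices
// backed by the same physical GPU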

////////////////////////////////////////////////////////////////////////////////
// backend interface
////////////////////////////////////////////////////////////////////////////////

// shared buffer

static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_free(ctx);
}

static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    return ggml_metal_buffer_get_base(ctx);
}

static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
}

static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
}

static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
}

static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    GGML_UNUSED(src);
    GGML_UNUSED(dst);

    return false;
}

static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_clear(ctx, value);
}

static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
    /* .free_buffer     = */ ggml_backend_metal_buffer_shared_free_buffer,
    /* .get_base        = */ ggml_backend_metal_buffer_shared_get_base,
    /* .init_tensor     = */ NULL,
    /* .memset_tensor   = */ ggml_backend_metal_buffer_shared_memset_tensor,
    /* .set_tensor      = */ ggml_backend_metal_buffer_shared_set_tensor,
    /* .get_tensor      = */ ggml_backend_metal_buffer_shared_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_metal_buffer_shared_cpy_tensor,
    /* .clear           = */ ggml_backend_metal_buffer_shared_clear,
    /* .reset           = */ NULL,
};

// private buffer

static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_free(ctx);
}

static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    return ggml_metal_buffer_get_base(ctx);
}

static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
}

static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
}

static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
}

static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    GGML_UNUSED(src);
    GGML_UNUSED(dst);

    return false;
}

static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;

    GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));

    ggml_metal_buffer_clear(ctx, value);
}

static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
    /* .free_buffer     = */ ggml_backend_metal_buffer_private_free_buffer,
    /* .get_base        = */ ggml_backend_metal_buffer_private_get_base,
    /* .init_tensor     = */ NULL,
    /* .memset_tensor   = */ ggml_backend_metal_buffer_private_memset_tensor,
    /* .set_tensor      = */ ggml_backend_metal_buffer_private_set_tensor,
    /* .get_tensor      = */ ggml_backend_metal_buffer_private_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_metal_buffer_private_cpy_tensor,
    /* .clear           = */ ggml_backend_metal_buffer_private_clear,
    /* .reset           = */ NULL,
};

static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
    return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer ||
           buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer;
}
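
// note: the identity of the .free_buffer function pointer doubles as a type tag:
// every buffer created by this backend uses one of the two interface tables
// above, so comparing a single member is sufficient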

//
// buffer types
//

struct ggml_backend_metal_buffer_type {
    int device;
    std::string name;
};

struct ggml_backend_metal_buffer_type_deleter {
    void operator()(ggml_backend_metal_buffer_type * ctx) const {
        delete ctx;
    }
};

typedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr;

// common method for allocating shared or private Metal buffers
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
    ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared);

    ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res)
        ? ggml_backend_metal_buffer_shared_i
        : ggml_backend_metal_buffer_private_i;

    return ggml_backend_buffer_init(buft, buf_i, res, size);
}
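
// usage sketch: callers normally reach the function above through the generic
// ggml-backend API rather than invoking it directly, e.g.:
//
//   ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, size);
//   // ... use the buffer ...
//   ggml_backend_buffer_free(buf);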

static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    size_t res = ggml_nbytes(tensor);

    // some operations require additional memory for fleeting data:
    switch (tensor->op) {
        case GGML_OP_MUL_MAT_ID:
            {
                res += ggml_metal_op_mul_mat_id_extra_tpe(tensor);
                res += ggml_metal_op_mul_mat_id_extra_ids(tensor);
            } break;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
                res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
            } break;
        case GGML_OP_CUMSUM:
        case GGML_OP_ARGSORT:
            {
                res *= 2;
            } break;
        case GGML_OP_TOP_K:
            {
                res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
            } break;
        default:
            break;
    }

    return res;

    GGML_UNUSED(buft);
}
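
// example: per the switch above, a GGML_OP_ARGSORT destination of N bytes is
// allocated 2*N bytes - the extra space presumably serves as scratch for an
// intermediate pass of the kernel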

// default (shared) buffer type

static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;

    return ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
}

static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {
    return 32;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;

    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
}

static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}

static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    static std::vector<ggml_backend_buffer_type> bufts;
    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;

    static bool initialized = false;
    if (!initialized) {
        bufts.reserve(g_devices);
        ctxs.reserve(g_devices);

        for (int i = 0; i < g_devices; ++i) {
            ggml_backend_metal_buffer_type * raw_ctx =
                new ggml_backend_metal_buffer_type {
                    /* .device = */ i,
                    /* .name   = */ GGML_METAL_NAME + std::to_string(i),
                };
            ctxs.emplace_back(raw_ctx);

            ggml_backend_buffer_type buft = {
                /* .iface = */ {
                    /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,
                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
                    /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
                },
                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
                /* .context = */ raw_ctx,
            };

            bufts.emplace_back(buft);
        }

        initialized = true;
    }

    return &bufts[device];
}
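
// note: the function-local statics plus the mutex give thread-safe, once-only
// construction; the returned pointer stays valid for the lifetime of the
// process because the vector is sized once and never resized afterwards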

// default (private) buffer type

static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;

    return ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);
}

static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {
    return 32;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;

    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
}

static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}

static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    static std::vector<ggml_backend_buffer_type> bufts;
    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;

    static bool initialized = false;
    if (!initialized) {
        bufts.reserve(g_devices);
        ctxs.reserve(g_devices);

        for (int i = 0; i < g_devices; ++i) {
            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
                /* .device = */ i,
                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Private"
            };
            ctxs.emplace_back(raw_ctx);

            ggml_backend_buffer_type buft = {
                /* .iface = */ {
                    /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,
                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
                    /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
                },
                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
                /* .context = */ raw_ctx,
            };

            bufts.emplace_back(buft);
        }

        initialized = true;
    }

    return &bufts[device];
}

// mapped buffer type

static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;

    return ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    // for mapped buffers, prefer shared memory
    return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
}

static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {
    return 32;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;

    return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
}

static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}

static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    static std::vector<ggml_backend_buffer_type> bufts;
    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;

    static bool initialized = false;
    if (!initialized) {
        bufts.reserve(g_devices);
        ctxs.reserve(g_devices);

        for (int i = 0; i < g_devices; ++i) {
            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
                /* .device = */ i,
                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped"
            };
            ctxs.emplace_back(raw_ctx);

            // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
            //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
            ggml_backend_buffer_type buft = {
                /* .iface = */ {
                    /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,
                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
                    /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
                },
                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
                /* .context = */ raw_ctx,
            };

            bufts.emplace_back(buft);
        }

        initialized = true;
    }

    return &bufts[device];
}

// backend

static const char * ggml_backend_metal_name(ggml_backend_t backend) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    return ggml_metal_get_name(ctx);
}

static void ggml_backend_metal_free(ggml_backend_t backend) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    // wait for any ongoing async operations to finish
    ggml_metal_synchronize(ctx);

    ggml_metal_free(ctx);

    free(backend);
}

static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_synchronize(ctx);
}

static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_set_tensor_async(ctx, tensor, data, offset, size);
}

static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_get_tensor_async(ctx, tensor, data, offset, size);
}

static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
    if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) {
        return false;
    }

    if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) {
        return false;
    }

    ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context;
    ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context;

    //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
    //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;

    //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context;
    //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context;

    return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);
}

static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    return ggml_metal_graph_compute(ctx, cgraph);
}

static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;
    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;

    ggml_metal_event_record(ctx, ev);
}

static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;
    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;

    ggml_metal_event_wait(ctx, ev);
}

static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_graph_optimize(ctx, cgraph);
}

static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_set_n_cb(ctx, n_cb);
}

static ggml_backend_i ggml_backend_metal_i = {
    /* .get_name                = */ ggml_backend_metal_name,
    /* .free                    = */ ggml_backend_metal_free,
    /* .set_tensor_async        = */ ggml_backend_metal_set_tensor_async,
    /* .get_tensor_async        = */ ggml_backend_metal_get_tensor_async,
    /* .cpy_tensor_async        = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
    /* .synchronize             = */ ggml_backend_metal_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
    /* .event_record            = */ ggml_backend_metal_event_record,
    /* .event_wait              = */ ggml_backend_metal_event_wait,
    /* .graph_optimize          = */ ggml_backend_metal_graph_optimize,
};
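
// note: the graph_plan_* hooks are left NULL; presumably command buffers are
// encoded fresh on every graph_compute call, so there is no reusable plan
// object to create or update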

static ggml_guid_t ggml_backend_metal_guid(void) {
    static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
    return &guid;
}

ggml_backend_t ggml_backend_metal_init(void) {
    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_t ctx = ggml_metal_init(ctx_dev);
    if (ctx == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return NULL;
    }

    ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));

    *backend = {
        /* .guid      = */ ggml_backend_metal_guid(),
        /* .interface = */ ggml_backend_metal_i,
        /* .device    = */ dev,
        /* .context   = */ ctx,
    };

    ggml_backend_metal_set_n_cb(backend, 1);

    return backend;
}
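
// usage sketch (typical client code):
//
//   ggml_backend_t backend = ggml_backend_metal_init();
//   if (backend) {
//       ggml_backend_graph_compute(backend, graph); // graph: a built ggml_cgraph
//       ggml_backend_free(backend);
//   }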

bool ggml_backend_is_metal(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
}

void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_set_abort_callback(ctx, abort_callback, user_data);
}

bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    return ggml_metal_supports_family(ctx, family);
}

void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_t ctx = (ggml_metal_t)backend->context;

    ggml_metal_capture_next_compute(ctx);
}
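
// note (assumption): the capture is armed one-shot, so a typical debugging flow
// is to call this right before the dispatch of interest:
//
//   ggml_backend_metal_capture_next_compute(backend);
//   ggml_backend_graph_compute(backend, graph); // this compute gets captured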

// backend device

static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);

    return props_dev->name;
}

static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    return ggml_metal_device_get_props(ctx_dev)->desc;
}

static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_device_get_memory(ctx_dev, free, total);
}

static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
    return GGML_BACKEND_DEVICE_TYPE_GPU;

    GGML_UNUSED(dev);
}

static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    props->name        = ggml_backend_metal_device_get_name(dev);
    props->description = ggml_backend_metal_device_get_description(dev);
    props->type        = ggml_backend_metal_device_get_type(dev);

    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);

    props->caps = {
        /* .async                = */ true,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ true,
        /* .events               = */ true,
    };
}

static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_t ctx = ggml_metal_init(ctx_dev);
    if (ctx == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return NULL;
    }

    ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));

    *backend = {
        /* .guid      = */ ggml_backend_metal_guid(),
        /* .interface = */ ggml_backend_metal_i,
        /* .device    = */ dev,
        /* .context   = */ ctx,
    };

    ggml_backend_metal_set_n_cb(backend, 1);

    return backend;

    GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);

    return props_dev->use_shared_buffers
        ? ggml_backend_metal_buffer_type_shared (props_dev->device)
        : ggml_backend_metal_buffer_type_private(props_dev->device);
}

static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);

    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);

    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size);
}

static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    return ggml_metal_device_supports_op(ctx_dev, op);
}

static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    return
        buft->device == dev && (
        buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
        buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name);
}

static int64_t get_op_batch_size(const ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_MUL_MAT:
            return op->ne[1];
        case GGML_OP_MUL_MAT_ID:
            return op->ne[2];
        default:
            return ggml_nrows(op);
    }
}
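
// example: for GGML_OP_MUL_MAT the result has ne[1] columns - one per column
// of src1, i.e. one per token in a typical LLM batch - which is why ne[1] is
// used as the batch size for the offload heuristic below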

static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    return (op->op == GGML_OP_MUL_MAT ||
            op->op == GGML_OP_MUL_MAT_ID) &&
            get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
}

static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev);
    GGML_ASSERT(event);

    ggml_backend_event_t ev = new ggml_backend_event {
        /* .device  = */ dev,
        /* .context = */ event,
    };

    return ev;
}

static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;

    ggml_metal_device_event_free(ctx_dev, ev);

    delete event;
}

static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;

    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;

    ggml_metal_device_event_synchronize(ctx_dev, ev);
}

static ggml_backend_device_i ggml_backend_metal_device_i = {
    /* .get_name             = */ ggml_backend_metal_device_get_name,
    /* .get_description      = */ ggml_backend_metal_device_get_description,
    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
    /* .get_type             = */ ggml_backend_metal_device_get_type,
    /* .get_props            = */ ggml_backend_metal_device_get_props,
    /* .init_backend         = */ ggml_backend_metal_device_init_backend,
    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
    /* .event_new            = */ ggml_backend_metal_device_event_new,
    /* .event_free           = */ ggml_backend_metal_device_event_free,
    /* .event_synchronize    = */ ggml_backend_metal_device_event_synchronize,
};

// backend registry

struct ggml_backend_metal_reg {
    std::vector<ggml_backend_dev_t> devices;
};

typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t;

static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) {
    ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg;

    return ctx;
}

static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) {
    delete ctx;
}

struct ggml_backend_metal_reg_deleter {
    void operator()(ggml_backend_metal_reg_t ctx) {
        ggml_backend_metal_reg_free(ctx);
    }
};

typedef std::unique_ptr<struct ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr;

static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
    return GGML_METAL_NAME;

    GGML_UNUSED(reg);
}

static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
    return ctx->devices.size();
}

static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
    GGML_ASSERT(index < ctx->devices.size());
    return ctx->devices[index];
}

static ggml_backend_feature g_ggml_backend_metal_features[] = {
#if defined(GGML_METAL_EMBED_LIBRARY)
    { "EMBED_LIBRARY", "1" },
#endif
    { NULL, NULL },
};

static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {
    return g_ggml_backend_metal_features;

    GGML_UNUSED(reg);
}

static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (strcmp(name, "ggml_backend_get_features") == 0) {
        return (void *)ggml_backend_metal_get_features;
    }

    return NULL;

    GGML_UNUSED(reg);
}

static ggml_backend_reg_i ggml_backend_metal_reg_i = {
    /* .get_name         = */ ggml_backend_metal_reg_get_name,
    /* .get_device_count = */ ggml_backend_metal_reg_device_count,
    /* .get_device       = */ ggml_backend_metal_reg_device_get,
    /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
};

static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) {
    return new ggml_backend_device {
        /* .iface   = */ ggml_backend_metal_device_i,
        /* .reg     = */ reg,
        /* .context = */ ggml_metal_device_get(device),
    };
}

static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) {
    delete dev;
}

struct ggml_backend_device_deleter {
    void operator()(ggml_backend_dev_t ctx) {
        ggml_backend_metal_device_free(ctx);
    }
};

typedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr;

ggml_backend_reg_t ggml_backend_metal_reg(void) {
    static ggml_backend_reg reg;
    static bool initialized = false;

    {
        static std::mutex mutex;
        std::lock_guard<std::mutex> lock(mutex);

        static std::vector<ggml_backend_device_ptr> devs;

        if (!initialized) {
            // read the device-count override once, before the devices are created
            const char * env = getenv("GGML_METAL_DEVICES");
            if (env) {
                g_devices = atoi(env);

                // clamp to a sane range (assumption: this is what GGML_METAL_MAX_DEVICES is for)
                if (g_devices < 1) {
                    g_devices = 1;
                }
                if (g_devices > GGML_METAL_MAX_DEVICES) {
                    g_devices = GGML_METAL_MAX_DEVICES;
                }
            }

            static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());

            for (int i = 0; i < g_devices; ++i) {
                auto * dev = ggml_backend_metal_device_init(&reg, i);
                devs.emplace_back(dev);

                reg_ctx->devices.push_back(dev);
            }

            reg = {
                /* .api_version = */ GGML_BACKEND_API_VERSION,
                /* .iface       = */ ggml_backend_metal_reg_i,
                /* .context     = */ reg_ctx.get(),
            };

            initialized = true;
        }
    }

    return &reg;
}
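
// usage sketch: enumerating the registered devices through the generic API:
//
//   ggml_backend_reg_t reg = ggml_backend_metal_reg();
//   for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
//       ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
//       printf("%s\n", ggml_backend_dev_name(dev));
//   }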

GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)