1#import "ggml-metal-context.h"
  2
  3#import "ggml-impl.h"
  4#import "ggml-backend-impl.h"
  5
  6#import "ggml-metal-impl.h"
  7#import "ggml-metal-common.h"
  8#import "ggml-metal-ops.h"
  9
 10#import <Foundation/Foundation.h>
 11
 12#import <Metal/Metal.h>
 13
// replace any previous MIN/MAX definitions (e.g. from the system headers) with simple local versions
// note: these are plain macros - the arguments may be evaluated twice, so avoid side effects in them
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// max number of MTLCommandBuffer used to submit a graph for processing
// (ggml_metal_set_n_cb clamps n_cb to this value; cmd_bufs has one extra slot for the main thread)
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
 21
// holder for one command buffer used to encode a slice of a graph
// the object stored in `obj` carries an explicit retain taken in ggml_metal_graph_compute;
// it is released when the slot is reused or in ggml_metal_free
struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};
 25
// per-backend Metal context
// this file manages Objective-C object lifetimes manually (explicit retain/release calls)
struct ggml_metal {
    // copy of the device name, filled in ggml_metal_init
    char name[128];

    ggml_metal_device_t  dev;
    ggml_metal_library_t lib;

    ggml_metal_event_t ev_cpy; // for async copies

    // concurrent GCD queue on which encode_async runs for the extra command buffers (dispatch_apply)
    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    ggml_metal_pipelines_t pipelines_ext;

    // feature toggles, read from environment variables in ggml_metal_init
    bool use_fusion;
    bool use_concurrency;
    bool use_graph_optimize;

    // debug levels from GGML_METAL_GRAPH_DEBUG / GGML_METAL_FUSION_DEBUG
    int debug_graph;
    int debug_fusion;

    // how many times a given op was fused
    uint64_t fuse_cnt[GGML_OP_COUNT];

    // capture state
    // capture_next_compute: request set by ggml_metal_capture_next_compute, consumed by the next graph compute
    // capture_started:      a GPU capture is currently in progress
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb; // nodes per extra thread (rounded up)

    // graph currently being encoded
    struct ggml_cgraph * gf;

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    // each entry is referenced by the array and additionally carries one explicit retain;
    // both are dropped in ggml_metal_synchronize / ggml_metal_free
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    id<MTLCommandBuffer> cmd_buf_last;

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void *              abort_callback_data;
};
 79
 80ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
 81    GGML_LOG_INFO("%s: allocating\n", __func__);
 82
 83#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
 84    // Show all the Metal device instances in the system
 85    NSArray * devices = MTLCopyAllDevices();
 86    for (id<MTLDevice> device in devices) {
 87        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
 88    }
 89    [devices release]; // since it was created by a *Copy* C method
 90#endif
 91
 92    // init context
 93    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
 94
 95    id<MTLDevice> device = ggml_metal_device_get_obj(dev);
 96
 97    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 98
 99    // TODO: would it be better to have one queue for the backend and one queue for the device?
100    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
101    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
102    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
103    if (queue == nil) {
104        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
105        return NULL;
106    }
107
108    res->dev = dev;
109    res->lib = ggml_metal_device_get_library(dev);
110    if (res->lib == NULL) {
111        GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
112        GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);
113
114        res->lib = ggml_metal_library_init(dev);
115        if (res->lib == NULL) {
116            GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);
117
118            free(res);
119
120            return NULL;
121        }
122    }
123
124    res->ev_cpy = ggml_metal_device_event_init(dev);
125
126    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
127
128    snprintf(res->name, sizeof(res->name), "%s", props_dev->name);
129
130    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
131
132    res->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == nil;
133    res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
134
135    {
136        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
137        res->debug_graph = val ? atoi(val) : 0;
138    }
139
140    {
141        const char * val = getenv("GGML_METAL_FUSION_DEBUG");
142        res->debug_fusion = val ? atoi(val) : 0;
143    }
144
145    res->use_graph_optimize = true;
146
147    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
148        res->use_graph_optimize = false;
149    }
150
151    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));
152
153    GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, res->use_fusion         ? "true" : "false");
154    GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
155    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
156
157    res->capture_next_compute = false;
158    res->capture_started = false;
159    res->capture_scope = nil;
160
161    res->gf = nil;
162    res->encode_async = nil;
163    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
164        res->cmd_bufs[i].obj = nil;
165    }
166
167    res->cmd_bufs_ext = [[NSMutableArray alloc] init];
168
169    res->cmd_buf_last = nil;
170
171    res->pipelines_ext = ggml_metal_pipelines_init();
172
173    return res;
174}
175
// Releases every resource owned by the context and frees it. Passing NULL is a no-op.
void ggml_metal_free(ggml_metal_t ctx) {
    if (ctx == NULL) {
        return;
    }

    GGML_LOG_INFO("%s: deallocating\n", __func__);

    // cmd_bufs has GGML_METAL_MAX_COMMAND_BUFFERS + 1 slots; the main thread uses slot n_cb, which can be
    // GGML_METAL_MAX_COMMAND_BUFFERS itself, so the last slot must be released as well
    for (int i = 0; i <= GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        if (ctx->cmd_bufs[i].obj) {
            [ctx->cmd_bufs[i].obj release];
            ctx->cmd_bufs[i].obj = nil;
        }
    }

    // drop the explicit retain taken on each extra command buffer, then the array's own references
    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
        if (ctx->cmd_bufs_ext[i]) {
            [ctx->cmd_bufs_ext[i] release];
        }
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    // finish a capture that is still in progress and drop the scope created with newCaptureScopeWithDevice: (+1)
    if (ctx->capture_started) {
        [ctx->capture_scope endScope];
        [[MTLCaptureManager sharedCaptureManager] stopCapture];
        ctx->capture_started = false;
    }

    if (ctx->capture_scope) {
        [ctx->capture_scope release];
        ctx->capture_scope = nil;
    }

    if (ctx->pipelines_ext) {
        ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    if (ctx->debug_fusion > 0) {
        GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use ggml_log here
            GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    // encode_async is only set once ggml_metal_set_n_cb() has been called
    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
        ctx->encode_async = nil;
    }

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);

    free(ctx);
}
221
// Returns the device name copied into the context at init time.
// The storage is embedded in the context, so the pointer is valid until ggml_metal_free().
const char * ggml_metal_get_name(ggml_metal_t ctx) {
    const char * name = &ctx->name[0];
    return name;
}
225
// Blocks until the last command buffer submitted by this backend has completed, then verifies
// that every submitted command buffer finished successfully (aborting the process otherwise)
// and releases the extra command buffers used for tensor get/set/copy and events.
void ggml_metal_synchronize(ggml_metal_t ctx) {
    // wait for any backend operations to finish
    if (ctx->cmd_buf_last) {
        [ctx->cmd_buf_last waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // check status of all graph command buffers (slot n_cb belongs to the main thread)
    {
        const int n_cb = ctx->n_cb;

        for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
            if (!cmd_buf) {
                continue;
            }

            MTLCommandBufferStatus status = [cmd_buf status];

            // while an abort callback is set, ggml_metal_graph_compute does not enqueue the command buffers
            // [2, n_cb) up front and commits them one at a time only while the callback returns false;
            // buffers left behind after an abort were never submitted to the GPU, so they are not a failure
            if (status == MTLCommandBufferStatusNotEnqueued) {
                continue;
            }

            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }
        }
    }

    // release any completed extra command buffers
    if (ctx->cmd_bufs_ext.count > 0) {
        for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }

            // drop the explicit retain taken when the buffer was recorded; the array reference goes below
            [cmd_buf release];
        }

        [ctx->cmd_bufs_ext removeAllObjects];
    }
}
274
// Resolves the Metal buffer and byte offset backing tensor `t` (views resolve through view_src).
// Returns { nil, 0 } when `t` is NULL or has no backing buffer yet; callers check `.metal == nil`.
static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
    if (!t) {
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

    // an unallocated tensor has no buffer - report "not found" instead of dereferencing NULL
    if (buffer == NULL) {
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    return ggml_metal_buffer_get_id(buffer->context, t);
}
284
// Asynchronously copies `size` bytes from host memory `data` into `tensor` at byte `offset`.
// The host data is first copied into a shared staging buffer, so `data` may be reused as soon
// as this returns. The blit is submitted on the device queue without waiting; the command
// buffer is recorded so ggml_metal_synchronize() can wait for it and release it.
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // staging buffer holding a copy of the host bytes (owned reference from new...)
        id<MTLBuffer> staging = [ggml_metal_device_get_obj(ctx->dev) newBufferWithBytes:data
                                                                                  length:size
                                                                                 options:MTLResourceStorageModeShared];
        GGML_ASSERT(staging);

        struct ggml_metal_buffer_id dst = ggml_metal_get_buffer_id(tensor);
        if (dst.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }
        dst.offs += offset;

        // the copy lands behind whatever is already queued on the device queue
        id<MTLCommandBuffer> blit_buf = [ggml_metal_device_get_queue(ctx->dev) commandBuffer];
        id<MTLBlitCommandEncoder> blit = [blit_buf blitCommandEncoder];

        [blit copyFromBuffer:staging
                sourceOffset:0
                    toBuffer:dst.metal
           destinationOffset:dst.offs
                        size:size];
        [blit endEncoding];

        [blit_buf commit];

        // a command buffer from -commandBuffer retains the resources it references,
        // so our reference to the staging buffer can be dropped right away
        [staging release];

        // no wait here: track the buffer and let synchronize() wait for it
        [ctx->cmd_bufs_ext addObject:blit_buf];
        ctx->cmd_buf_last = blit_buf;
        [blit_buf retain];
    }
}
328
// Asynchronously copies `size` bytes of `tensor`, starting at byte `offset`, into host memory `data`.
// `data` is wrapped without copying (newBufferWithBytesNoCopy), so it must stay valid and untouched
// until the copy completes (e.g. after ggml_metal_synchronize()). Metal's no-copy API expects
// page-aligned memory; if the wrap fails the assert below fires.
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        id<MTLBuffer> host_view = [ggml_metal_device_get_obj(ctx->dev) newBufferWithBytesNoCopy:data
                                                                                           length:size
                                                                                          options:MTLResourceStorageModeShared
                                                                                      deallocator:nil];
        GGML_ASSERT(host_view);

        struct ggml_metal_buffer_id src = ggml_metal_get_buffer_id(tensor);
        if (src.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }
        src.offs += offset;

        // the copy lands behind whatever is already queued on the device queue
        id<MTLCommandBuffer> blit_buf = [ggml_metal_device_get_queue(ctx->dev) commandBuffer];
        id<MTLBlitCommandEncoder> blit = [blit_buf blitCommandEncoder];

        [blit copyFromBuffer:src.metal
                sourceOffset:src.offs
                    toBuffer:host_view
           destinationOffset:0
                        size:size];
        [blit endEncoding];

        [blit_buf commit];

        // the command buffer keeps its own reference to host_view while the blit is pending
        [host_view release];

        // no wait here: track the buffer and let synchronize() wait for it
        [ctx->cmd_bufs_ext addObject:blit_buf];
        ctx->cmd_buf_last = blit_buf;
        [blit_buf retain];
    }
}
372
// Asynchronously copies all bytes of `src` into `dst` (both must live in Metal buffers).
// The blit is submitted through ctx_src, which then signals its ev_cpy event; ctx_dst gets a
// wait on that event via ggml_metal_event_wait. Returns false (nothing queued) if either
// tensor has no Metal buffer.
bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    @autoreleasepool {
        const struct ggml_metal_buffer_id from = ggml_metal_get_buffer_id(src);
        const struct ggml_metal_buffer_id to   = ggml_metal_get_buffer_id(dst);

        const bool both_found = from.metal != nil && to.metal != nil;
        if (!both_found) {
            return false;
        }

        // the copy lands behind whatever is already queued on the device queue
        id<MTLCommandBuffer> blit_buf = [ggml_metal_device_get_queue(ctx_src->dev) commandBuffer];
        id<MTLBlitCommandEncoder> blit = [blit_buf blitCommandEncoder];

        [blit copyFromBuffer:from.metal
                sourceOffset:from.offs
                    toBuffer:to.metal
           destinationOffset:to.offs
                        size:ggml_nbytes(src)];
        [blit endEncoding];

        // the same command buffer signals the source context's copy event after the blit
        ggml_metal_event_t copied = ggml_metal_get_ev_cpy(ctx_src);
        ggml_metal_event_encode_signal(copied, blit_buf);

        [blit_buf commit];

        // no wait here: track the buffer on the source context for synchronize()
        [ctx_src->cmd_bufs_ext addObject:blit_buf];
        ctx_src->cmd_buf_last = blit_buf;
        [blit_buf retain];

        // make the destination context wait on the copy event
        ggml_metal_event_wait(ctx_dst, copied);

        return true;
    }
}
415
// Encodes and submits the compute graph `gf` to the GPU.
//
// The first n_nodes_0 nodes are encoded by the calling thread into cmd_bufs[n_cb]; the rest are
// split across n_cb extra command buffers encoded in parallel via ctx->encode_async.
// Normally the graph is left to run asynchronously and GGML_STATUS_SUCCESS is returned.
// The call waits for completion (and may return GGML_STATUS_FAILED / GGML_STATUS_ABORTED) when
// a GPU capture started by the previous call has to be finished, or when command buffers had to
// be held back because an abort callback is set.
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = MAX(64, 0.1*gf->n_nodes);

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // encode_async is installed by ggml_metal_set_n_cb(); invoking a nil block below would crash
    GGML_ASSERT(ctx->encode_async != nil);

    // keep the memory wired
    ggml_metal_device_rsets_keep_alive(ctx->dev);

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates its own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2

    enum ggml_status status = GGML_STATUS_SUCCESS;

    @autoreleasepool {
        ctx->gf = gf;

        // with no extra threads (n_cb == 0) nobody would encode the tail of the graph,
        // so the main thread takes all of the nodes in that case
        ctx->n_nodes_0 = n_cb > 0 ? MIN(n_main, gf->n_nodes) : gf->n_nodes;
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        // nodes per extra thread, rounded up (guarded against division by zero)
        ctx->n_nodes_per_cb = n_cb > 0 ? (ctx->n_nodes_1 + n_cb - 1) / n_cb : 0;

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            ctx->capture_next_compute = false;

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                // create capture scope (owned reference, released when the capture is stopped)
                id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                const bool started = [[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error];

                // +new returned an owned reference that is no longer needed
                [descriptor release];

                if (!started) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);

                    // drop the unused scope so that a later capture attempt does not leak it
                    [ctx->capture_scope release];
                    ctx->capture_scope = nil;
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }

        // short-hand
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            ctx->encode_async(n_cb);

            // this buffer is always enqueued and every later buffer on the queue runs behind it,
            // so it must be committed even if encode_async skipped the commit
            if ([cmd_buf status] == MTLCommandBufferStatusEnqueued) {
                [cmd_buf commit];
            }
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue];

                // update the pointer to the last queued command buffer
                // this is needed to implement synchronize()
                ctx->cmd_buf_last = cmd_buf;
            }
        }

        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // for debugging: block until graph is computed
        //[ctx->cmd_buf_last waitUntilCompleted];

        // a capture started by the previous call is finished during this one
        const bool stop_capture = !use_capture && ctx->capture_started;

        // with an abort callback set, command buffers [2, n_cb) were neither enqueued nor committed above;
        // only the loop below submits them (after polling the callback), so it must run in that case too
        const bool has_deferred = ctx->abort_callback != NULL && n_cb > 2;

        // otherwise, we leave the graph to compute asynchronously
        if (stop_capture || has_deferred) {
            // wait for completion and check status of each command buffer
            // needed to detect if the device ran out-of-memory for example (#1881)
            {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus cb_status = [cmd_buf status];
                if (cb_status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, (unsigned long) cb_status);
                    if (cb_status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    status = GGML_STATUS_FAILED;
                }
            }

            for (int i = 0; status == GGML_STATUS_SUCCESS && i < n_cb; ++i) {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus cb_status = [cmd_buf status];
                if (cb_status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, (unsigned long) cb_status);
                    if (cb_status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    status = GGML_STATUS_FAILED;
                    break;
                }

                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
                if (!next_buffer) {
                    continue;
                }

                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
                if (next_queued) {
                    continue;
                }

                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                    GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);

                    status = GGML_STATUS_ABORTED;
                    break;
                }

                [next_buffer commit];
            }

            // finish the capture even when the graph failed or was aborted
            if (stop_capture) {
                [ctx->capture_scope endScope];
                [[MTLCaptureManager sharedCaptureManager] stopCapture];

                // release the scope and reset the state so later computes run asynchronously again
                [ctx->capture_scope release];
                ctx->capture_scope = nil;
                ctx->capture_started = false;
            }
        }
    }

    return status;
}
579
// Runs ggml_graph_optimize() on `gf` in place, unless graph optimization was disabled
// (GGML_METAL_GRAPH_OPTIMIZE_DISABLE, read in ggml_metal_init).
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    if (!ctx->use_graph_optimize) {
        return;
    }

    // timing hook, kept for debugging:
    //const int64_t t_start = ggml_time_us();
    ggml_graph_optimize(gf);
    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
}
589
// Submits a command buffer on the device queue whose only content is a signal of `ev`
// (encoded by ggml_metal_event_encode_signal). Does not wait; the buffer is tracked so
// ggml_metal_synchronize() can wait for it and release it.
void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandBuffer> signal_buf = [ggml_metal_device_get_queue(ctx->dev) commandBuffer];

        ggml_metal_event_encode_signal(ev, signal_buf);
        [signal_buf commit];

        // track it as the most recent work of this backend (extra retain dropped in synchronize/free)
        [ctx->cmd_bufs_ext addObject:signal_buf];
        ctx->cmd_buf_last = signal_buf;
        [signal_buf retain];
    }
}
605
// Submits a command buffer on the device queue whose only content is a wait on `ev`
// (encoded by ggml_metal_event_encode_wait). Does not block the CPU; the buffer is tracked so
// ggml_metal_synchronize() can wait for it and release it.
void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandBuffer> wait_buf = [ggml_metal_device_get_queue(ctx->dev) commandBuffer];

        ggml_metal_event_encode_wait(ev, wait_buf);
        [wait_buf commit];

        // track it as the most recent work of this backend (extra retain dropped in synchronize/free)
        [ctx->cmd_bufs_ext addObject:wait_buf];
        ctx->cmd_buf_last = wait_buf;
        [wait_buf retain];
    }
}
621
// Returns the event this context signals after an async tensor copy (created in ggml_metal_init).
ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
    ggml_metal_event_t ev = ctx->ev_cpy;
    return ev;
}
625
// Sets the number of extra command buffers (and encoding threads) used per graph and
// (re)installs ctx->encode_async, the callback that encodes one slice of ctx->gf.
// n_cb is clamped to [0, GGML_METAL_MAX_COMMAND_BUFFERS].
void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
    if (ctx->n_cb != n_cb) {
        // the main thread uses slot cmd_bufs[n_cb], which exists only for 0 <= n_cb <= GGML_METAL_MAX_COMMAND_BUFFERS
        ctx->n_cb = MAX(0, MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS));

        if (ctx->n_cb > 2) {
            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
        }
    }

    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    // encode_async(iter) encodes into ctx->cmd_bufs[iter]:
    //   iter == n_cb : main thread, nodes [0, n_nodes_0)
    //   iter <  n_cb : extra thread, its share of the remaining n_nodes_1 nodes
    ctx->encode_async = Block_copy(^(size_t iter) {
        const int cb_idx = iter;
        const int n_cb_l = ctx->n_cb;

        const int n_nodes_0 = ctx->n_nodes_0;
        const int n_nodes_1 = ctx->n_nodes_1;

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

        int idx_start = 0;
        int idx_end   = n_nodes_0;

        if (cb_idx < n_cb_l) {
            idx_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
            idx_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
        }

        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;

        ggml_metal_op_t ctx_op = ggml_metal_op_init(
            ctx->dev,
            cmd_buf,
            ctx->gf,
            idx_start,
            idx_end,
            ctx->use_fusion,
            ctx->use_concurrency,
            // ggml_metal_graph_compute clears capture_next_compute before encoding starts, so it is always
            // false here; capture_started is what tells whether this graph is being captured
            ctx->capture_started,
            ctx->debug_graph,
            ctx->debug_fusion);

        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
            const int res = ggml_metal_op_encode(ctx_op, idx);
            if (res == 0) {
                break;
            }

            // res nodes were encoded at once (e.g. fused) - skip the rest of them
            idx += res - 1;
        }

        ggml_metal_op_free(ctx_op);

        // the main-thread buffer (slot n_cb_l) is always enqueued by ggml_metal_graph_compute and every later
        // buffer on the queue runs behind it, so it must always be committed. Of the extra buffers, only the
        // first two are committed here while an abort callback is set; the others are left for
        // ggml_metal_graph_compute to commit after checking the callback.
        const bool is_main = cb_idx == n_cb_l;
        if (is_main || cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
    });
}
686
// Installs (or clears, with NULL) the callback polled by ggml_metal_graph_compute to abort a graph;
// `user_data` is passed back to it on every call.
void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
    ctx->abort_callback_data = user_data;
    ctx->abort_callback      = abort_callback;
}
691
// Reports whether the context's device supports Apple GPU family `family`, where
// family N is mapped onto MTLGPUFamilyApple1 + (N - 1).
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
    GGML_ASSERT(ctx->dev != nil);

    const MTLGPUFamily gpu_family = (MTLGPUFamily) (MTLGPUFamilyApple1 + family - 1);

    return [ggml_metal_device_get_obj(ctx->dev) supportsFamily:gpu_family];
}
699
// Requests a GPU trace capture of the next graph; the flag is consumed (and cleared)
// by the next ggml_metal_graph_compute() call.
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
    const bool request = true;
    ctx->capture_next_compute = request;
}