path: root/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
author    Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
committer Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
commit    b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree      211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
download  llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m')
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m  |  702
1 file changed, 702 insertions(+), 0 deletions(-)
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
new file mode 100644
index 0000000..5d3a8ce
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
@@ -0,0 +1,702 @@
#import "ggml-metal-context.h"

#import "ggml-impl.h"
#import "ggml-backend-impl.h"

#import "ggml-metal-impl.h"
#import "ggml-metal-common.h"
#import "ggml-metal-ops.h"

#import <Foundation/Foundation.h>

#import <Metal/Metal.h>

#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8

struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};

struct ggml_metal {
    char name[128];

    ggml_metal_device_t  dev;
    ggml_metal_library_t lib;

    ggml_metal_event_t ev_cpy; // for async copies

    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    ggml_metal_pipelines_t pipelines_ext;

    bool use_fusion;
    bool use_concurrency;
    bool use_graph_optimize;

    int debug_graph;
    int debug_fusion;

    // how many times a given op was fused
    uint64_t fuse_cnt[GGML_OP_COUNT];

    // capture state
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb;

    struct ggml_cgraph * gf;

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    id<MTLCommandBuffer> cmd_buf_last;

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void * abort_callback_data;
};

ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
    GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context
    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));

    id<MTLDevice> device = ggml_metal_device_get_obj(dev);

    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
    if (queue == nil) {
        GGML_LOG_ERROR("%s: error: failed to get device command queue\n", __func__);

        free(res);

        return NULL;
    }

    res->dev = dev;
    res->lib = ggml_metal_device_get_library(dev);
    if (res->lib == NULL) {
        GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
        GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);

        res->lib = ggml_metal_library_init(dev);
        if (res->lib == NULL) {
            GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);

            free(res);

            return NULL;
        }
    }

    res->ev_cpy = ggml_metal_device_event_init(dev);

    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);

    snprintf(res->name, sizeof(res->name), "%s", props_dev->name);

    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    res->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == nil;
    res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;

    {
        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
        res->debug_graph = val ? atoi(val) : 0;
    }

    {
        const char * val = getenv("GGML_METAL_FUSION_DEBUG");
        res->debug_fusion = val ? atoi(val) : 0;
    }

    res->use_graph_optimize = true;

    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
        res->use_graph_optimize = false;
    }

    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

    GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, res->use_fusion         ? "true" : "false");
    GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

    res->capture_next_compute = false;
    res->capture_started      = false;
    res->capture_scope        = nil;

    res->gf           = nil;
    res->encode_async = nil;
    // clear all command buffer slots, including the extra one used by the main thread
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS + 1; ++i) {
        res->cmd_bufs[i].obj = nil;
    }

    res->cmd_bufs_ext = [[NSMutableArray alloc] init];

    res->cmd_buf_last = nil;

    res->pipelines_ext = ggml_metal_pipelines_init();

    return res;
}

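// usage sketch (not part of this file; assumes a ggml_metal_device_t `dev` obtained from the
// device layer of the Metal backend):
//
//   ggml_metal_t ctx = ggml_metal_init(dev);
//   if (ctx != NULL) {
//       ggml_metal_set_n_cb(ctx, 1);          // builds the encode_async block (required before compute)
//       ggml_metal_graph_compute(ctx, gf);    // submit a ggml graph asynchronously
//       ggml_metal_synchronize(ctx);          // wait for the GPU to finish
//       ggml_metal_free(ctx);
//   }
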
void ggml_metal_free(ggml_metal_t ctx) {
    GGML_LOG_INFO("%s: deallocating\n", __func__);

    // release all command buffer slots, including the extra one at index n_cb used by the main thread
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS + 1; ++i) {
        if (ctx->cmd_bufs[i].obj) {
            [ctx->cmd_bufs[i].obj release];
        }
    }

    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
        if (ctx->cmd_bufs_ext[i]) {
            [ctx->cmd_bufs_ext[i] release];
        }
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    if (ctx->pipelines_ext) {
        ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    if (ctx->debug_fusion > 0) {
        GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use ggml_log here
            GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    Block_release(ctx->encode_async);

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);

    free(ctx);
}

const char * ggml_metal_get_name(ggml_metal_t ctx) {
    return ctx->name;
}

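// wait for all queued GPU work to finish; command buffers on the shared device queue execute in
// submission order, so waiting on the most recently committed one (cmd_buf_last) covers the rest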
void ggml_metal_synchronize(ggml_metal_t ctx) {
    // wait for any backend operations to finish
    if (ctx->cmd_buf_last) {
        [ctx->cmd_buf_last waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // check status of all command buffers
    {
        const int n_cb = ctx->n_cb;

        for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
            if (!cmd_buf) {
                continue;
            }

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }
        }
    }

    // release any completed extra command buffers
    if (ctx->cmd_bufs_ext.count > 0) {
        for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }

            [cmd_buf release];
        }

        [ctx->cmd_bufs_ext removeAllObjects];
    }
}

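// resolve the Metal buffer that backs a tensor, together with the offset of the tensor's data
// within it; tensor views resolve to the buffer of their view source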
static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
    if (!t) {
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

    return ggml_metal_buffer_get_id(buffer->context, t);
}

void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a Metal buffer
        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
        id<MTLBuffer> buf_src = [device newBufferWithBytes:data
                                                    length:size
                                                   options:MTLResourceStorageModeShared];

        GGML_ASSERT(buf_src);

        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
        if (bid_dst.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_dst.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
                   sourceOffset:0
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];
        [buf_src release];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

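        // keep the command buffer alive past this autorelease pool; this retain and the one taken
        // by addObject above are both balanced by the releases in ggml_metal_synchronize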
        [cmd_buf retain];
    }
}

void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
        id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
                                                          length:size
                                                         options:MTLResourceStorageModeShared
                                                     deallocator:nil];

        GGML_ASSERT(buf_dst);

        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
        if (bid_src.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_src.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:buf_dst
              destinationOffset:0
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];
        [buf_dst release];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}

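// async copy between two Metal contexts: the blit is encoded on the source context's queue and the
// ev_cpy event makes the destination context wait for it before using the data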
bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    @autoreleasepool {
        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);
        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);

        if (bid_src.metal == nil || bid_dst.metal == nil) {
            return false;
        }

        // queue the copy operation into the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:ggml_nbytes(src)];

        [encoder endEncoding];

        ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
        ggml_metal_event_encode_signal(ev_cpy, cmd_buf);

        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx_src->cmd_bufs_ext addObject:cmd_buf];
        ctx_src->cmd_buf_last = cmd_buf;

        [cmd_buf retain];

        ggml_metal_event_wait(ctx_dst, ev_cpy);

        return true;
    }
}

enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = MAX(64, 0.1*gf->n_nodes);

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // keep the memory wired
    ggml_metal_device_rsets_keep_alive(ctx->dev);

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates its own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models show that optimal values for n_cb are 1 or 2

    @autoreleasepool {
        ctx->gf = gf;

        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
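
        // e.g. for gf->n_nodes = 1000 and n_cb = 2:
        //   n_nodes_0 = max(64, 100) = 100 nodes encoded by the main thread
        //   n_nodes_1 = 900 remaining nodes, n_nodes_per_cb = 450 per extra command buffer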

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            ctx->capture_next_compute = false;

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                // create capture scope
                id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }
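        // note: a capture is requested with ggml_metal_capture_next_compute(); the resulting
        //       /tmp/perf-metal.gputrace document can be opened with Xcode's Metal debugger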

        // short-hand
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            ctx->encode_async(n_cb);
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0 .. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue];

                // update the pointer to the last queued command buffer
                // this is needed to implement synchronize()
                ctx->cmd_buf_last = cmd_buf;
            }
        }

        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // for debugging: block until graph is computed
        //[ctx->cmd_buf_last waitUntilCompleted];

        // enter here only when capturing in order to wait for all computation to finish
        // otherwise, we leave the graph to compute asynchronously
        if (use_capture && ctx->capture_started) {
            // wait for completion and check status of each command buffer
            // needed to detect if the device ran out-of-memory for example (#1881)
            {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_ERROR("%s: command buffer %d failed with status %d\n", __func__, n_cb, (int) status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }
            }

            for (int i = 0; i < n_cb; ++i) {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_ERROR("%s: command buffer %d failed with status %d\n", __func__, i, (int) status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }

                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
                if (!next_buffer) {
                    continue;
                }

                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
                if (next_queued) {
                    continue;
                }

                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                    GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);
                    return GGML_STATUS_ABORTED;
                }

                [next_buffer commit];
            }

            [ctx->capture_scope endScope];
            [[MTLCaptureManager sharedCaptureManager] stopCapture];
        }
    }

    return GGML_STATUS_SUCCESS;
}

void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    //const int64_t t_start = ggml_time_us();

    if (ctx->use_graph_optimize) {
        ggml_graph_optimize(gf);
    }

    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
}

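// event record/wait pair used to order work across backends: record signals the event from this
// context's queue; wait makes subsequent work on this queue block on the event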
void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];

        ggml_metal_event_encode_signal(ev, cmd_buf);

        [cmd_buf commit];

        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}

void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];

        ggml_metal_event_encode_wait(ev, cmd_buf);

        [cmd_buf commit];

        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}

ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
    return ctx->ev_cpy;
}

void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
    if (ctx->n_cb != n_cb) {
        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);

        if (ctx->n_cb > 2) {
            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
        }
    }

    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    ctx->encode_async = Block_copy(^(size_t iter) {
        const int cb_idx = iter;
        const int n_cb_l = ctx->n_cb;

        const int n_nodes_0 = ctx->n_nodes_0;
        const int n_nodes_1 = ctx->n_nodes_1;

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

        int idx_start = 0;
        int idx_end   = n_nodes_0;

        if (cb_idx < n_cb_l) {
            idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
            idx_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
        }
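
        // e.g. with n_nodes_0 = 100, n_nodes_1 = 900, n_cb = 2, n_nodes_per_cb = 450:
        //   cb_idx == n_cb (main thread) -> nodes [  0,  100)
        //   cb_idx == 0                  -> nodes [100,  550)
        //   cb_idx == 1 (last)           -> nodes [550, 1000)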

        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;

        ggml_metal_op_t ctx_op = ggml_metal_op_init(
            ctx->dev,
            cmd_buf,
            ctx->gf,
            idx_start,
            idx_end,
            ctx->use_fusion,
            ctx->use_concurrency,
            ctx->capture_next_compute,
            ctx->debug_graph,
            ctx->debug_fusion);

        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
            const int res = ggml_metal_op_encode(ctx_op, idx);
            if (res == 0) {
                break;
            }

            idx += res - 1;
        }

        ggml_metal_op_free(ctx_op);

        if (cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
    });
}

void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
    ctx->abort_callback      = abort_callback;
    ctx->abort_callback_data = user_data;
}

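// map ggml's 1-based Apple GPU family index onto MTLGPUFamilyApple1..N;
// e.g. ggml_metal_supports_family(ctx, 7) checks MTLGPUFamilyApple7 (A14/M1 generation)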
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
    GGML_ASSERT(ctx->dev != nil);

    id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);

    return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}

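// request a GPU frame capture of the next call to ggml_metal_graph_compute (see the capture logic there)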
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
    ctx->capture_next_compute = true;
}