diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m | 702 |
1 file changed, 702 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m new file mode 100644 index 0000000..5d3a8ce --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m | |||
| @@ -0,0 +1,702 @@ | |||
| 1 | #import "ggml-metal-context.h" | ||
| 2 | |||
| 3 | #import "ggml-impl.h" | ||
| 4 | #import "ggml-backend-impl.h" | ||
| 5 | |||
| 6 | #import "ggml-metal-impl.h" | ||
| 7 | #import "ggml-metal-common.h" | ||
| 8 | #import "ggml-metal-ops.h" | ||
| 9 | |||
| 10 | #import <Foundation/Foundation.h> | ||
| 11 | |||
| 12 | #import <Metal/Metal.h> | ||
| 13 | |||
| 14 | #undef MIN | ||
| 15 | #undef MAX | ||
| 16 | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||
| 17 | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||
| 18 | |||
| 19 | // max number of MTLCommandBuffer used to submit a graph for processing | ||
| 20 | #define GGML_METAL_MAX_COMMAND_BUFFERS 8 | ||
| 21 | |||
// thin wrapper around a single MTLCommandBuffer used for graph submission;
// kept as a struct so the fixed-size cmd_bufs array below stores them by value
struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};
| 25 | |||
// State of a single Metal backend instance.
// Owns the per-backend dispatch queue, the command-buffer bookkeeping and the
// runtime options (fusion/concurrency/graph-optimize) read from the environment.
struct ggml_metal {
    char name[128]; // device name, returned by ggml_metal_get_name()

    ggml_metal_device_t dev;
    ggml_metal_library_t lib;

    ggml_metal_event_t ev_cpy; // for async copies

    // concurrent queue used by dispatch_apply to run the encode_async workers
    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    ggml_metal_pipelines_t pipelines_ext;

    bool use_fusion;         // cleared by GGML_METAL_FUSION_DISABLE
    bool use_concurrency;    // cleared by GGML_METAL_CONCURRENCY_DISABLE
    bool use_graph_optimize; // cleared by GGML_METAL_GRAPH_OPTIMIZE_DISABLE

    int debug_graph;  // from GGML_METAL_GRAPH_DEBUG
    int debug_fusion; // from GGML_METAL_FUSION_DEBUG

    // how many times a given op was fused
    uint64_t fuse_cnt[GGML_OP_COUNT];

    // capture state
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb;

    struct ggml_cgraph * gf; // graph currently being encoded (set by graph_compute)

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread (main thread uses index n_cb)
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    id<MTLCommandBuffer> cmd_buf_last;

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void * abort_callback_data;
};
| 79 | |||
// Create a new Metal backend context for the given device.
// Reads runtime options from the environment and returns NULL on failure
// (allocation failure, no command queue, or no Metal library).
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
    GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context
    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
    if (res == NULL) {
        // fix: previously a NULL calloc result was dereferenced below
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return NULL;
    }

    id<MTLDevice> device = ggml_metal_device_get_obj(dev);

    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
    if (queue == nil) {
        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);

        free(res); // fix: do not leak the context on this early-exit path

        return NULL;
    }

    res->dev = dev;
    res->lib = ggml_metal_device_get_library(dev);
    if (res->lib == NULL) {
        GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
        GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);

        res->lib = ggml_metal_library_init(dev);
        if (res->lib == NULL) {
            GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);

            free(res);

            return NULL;
        }
    }

    res->ev_cpy = ggml_metal_device_event_init(dev);

    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);

    snprintf(res->name, sizeof(res->name), "%s", props_dev->name);

    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    // features are opt-out: present env var disables them
    res->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE")      == NULL;
    res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == NULL;

    {
        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
        res->debug_graph = val ? atoi(val) : 0;
    }

    {
        const char * val = getenv("GGML_METAL_FUSION_DEBUG");
        res->debug_fusion = val ? atoi(val) : 0;
    }

    res->use_graph_optimize = true;

    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
        res->use_graph_optimize = false;
    }

    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

    GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
    GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

    res->capture_next_compute = false;
    res->capture_started = false;
    res->capture_scope = nil;

    res->gf = nil;
    res->encode_async = nil;

    // fix: the array has GGML_METAL_MAX_COMMAND_BUFFERS + 1 entries (the extra
    // slot is the main-thread buffer) - initialize all of them, not just the
    // first GGML_METAL_MAX_COMMAND_BUFFERS (calloc already zeroed the memory,
    // but keep the explicit loop consistent with the array size)
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS + 1; ++i) {
        res->cmd_bufs[i].obj = nil;
    }

    res->cmd_bufs_ext = [[NSMutableArray alloc] init];

    res->cmd_buf_last = nil;

    res->pipelines_ext = ggml_metal_pipelines_init();

    return res;
}
| 175 | |||
// Destroy a Metal backend context created by ggml_metal_init.
// Releases all retained command buffers, the dispatch queue, the copy event
// and the extension pipelines; prints fusion statistics when enabled.
void ggml_metal_free(ggml_metal_t ctx) {
    GGML_LOG_INFO("%s: deallocating\n", __func__);

    // fix: the cmd_bufs array holds GGML_METAL_MAX_COMMAND_BUFFERS + 1 entries
    // and the main-thread buffer lives at index n_cb (which can equal
    // GGML_METAL_MAX_COMMAND_BUFFERS) - the previous bound missed the last
    // slot and leaked that command buffer
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS + 1; ++i) {
        if (ctx->cmd_bufs[i].obj) {
            [ctx->cmd_bufs[i].obj release];
        }
    }

    // balance the explicit retains taken when the extra buffers were queued
    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
        if (ctx->cmd_bufs_ext[i]) {
            [ctx->cmd_bufs_ext[i] release];
        }
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    if (ctx->pipelines_ext) {
        ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    if (ctx->debug_fusion > 0) {
        GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use ggml_log here
            GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    // encode_async is nil until ggml_metal_set_n_cb is first called
    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);

    free(ctx);
}
| 221 | |||
// Returns the human-readable device name captured at context creation.
const char * ggml_metal_get_name(ggml_metal_t ctx) {
    return ctx->name;
}
| 225 | |||
// Block until all previously submitted GPU work for this context has finished.
// Validates the status of every command buffer (aborting the process on GPU
// failure, e.g. out-of-memory) and releases the completed extra command
// buffers queued by the async get/set/copy paths.
void ggml_metal_synchronize(ggml_metal_t ctx) {
    // wait for any backend operations to finish
    // waiting on the last queued buffer suffices because buffers on a single
    // MTLCommandQueue execute in enqueue order
    if (ctx->cmd_buf_last) {
        [ctx->cmd_buf_last waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // check status of all command buffers
    {
        const int n_cb = ctx->n_cb;

        // cmd_bufs[0..n_cb) are the worker buffers, cmd_bufs[n_cb] is the main-thread buffer
        for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
            if (!cmd_buf) {
                continue;
            }

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                // a non-completed buffer at this point indicates a GPU-side failure
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }
        }
    }

    // release any completed extra command buffers
    if (ctx->cmd_bufs_ext.count > 0) {
        for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];

            MTLCommandBufferStatus status = [cmd_buf status];
            if (status != MTLCommandBufferStatusCompleted) {
                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
                if (status == MTLCommandBufferStatusError) {
                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
                GGML_ABORT("fatal error");
            }

            // balances the explicit retain taken when the buffer was queued
            [cmd_buf release];
        }

        [ctx->cmd_bufs_ext removeAllObjects];
    }
}
| 274 | |||
// Resolve the Metal buffer id backing tensor `t`.
// Returns a { nil, 0 } id when `t` is NULL.
static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
    struct ggml_metal_buffer_id bid = { nil, 0 };

    if (t) {
        // a view borrows the storage of its source tensor
        ggml_backend_buffer_t buffer = t->buffer;
        if (t->view_src) {
            buffer = t->view_src->buffer;
        }

        bid = ggml_metal_buffer_get_id(buffer->context, t);
    }

    return bid;
}
| 284 | |||
// Asynchronously upload `size` bytes from host memory `data` into `tensor`
// starting at byte offset `offset`. The copy is queued on the device queue
// after any in-flight GPU work; use ggml_metal_synchronize() to wait for it.
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a Metal buffer (this copies the host bytes,
        // so the caller's `data` does not have to outlive the call)
        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
        id<MTLBuffer> buf_src = [device newBufferWithBytes:data
                                                    length:size
                                                   options:MTLResourceStorageModeShared];

        GGML_ASSERT(buf_src);

        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
        if (bid_dst.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_dst.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
                   sourceOffset:0
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];
        // safe to drop our reference here: the committed command buffer keeps
        // buf_src alive until execution finishes
        [buf_src release];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        // (this retain is balanced by the release in ggml_metal_synchronize)
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}
| 328 | |||
// Asynchronously download `size` bytes of `tensor` (from byte offset `offset`)
// into host memory `data`. The copy is queued after any in-flight GPU work;
// the data is only valid after ggml_metal_synchronize() has been called.
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the destination host memory without copying it
        // NOTE(review): newBufferWithBytesNoCopy requires a page-aligned
        // pointer and length on macOS - assumes callers pass suitable memory;
        // verify against the call sites
        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
        id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
                                                          length:size
                                                         options:MTLResourceStorageModeShared
                                                     deallocator:nil];

        GGML_ASSERT(buf_dst);

        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
        if (bid_src.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_src.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:buf_dst
              destinationOffset:0
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];
        // the committed command buffer keeps buf_dst alive until execution finishes
        [buf_dst release];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        // (this retain is balanced by the release in ggml_metal_synchronize)
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        [cmd_buf retain];
    }
}
| 372 | |||
// Asynchronously copy tensor `src` (owned by ctx_src) into `dst` (owned by
// ctx_dst) on the GPU. Returns false when either tensor is not backed by a
// Metal buffer, in which case the caller must use a fallback path.
bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    @autoreleasepool {
        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);
        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);

        if (bid_src.metal == nil || bid_dst.metal == nil) {
            return false;
        }

        // queue the copy operation into the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev);
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:ggml_nbytes(src)];

        [encoder endEncoding];

        // signal the source context's copy event once the blit finishes
        ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
        ggml_metal_event_encode_signal(ev_cpy, cmd_buf);

        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        // (this retain is balanced by the release in ggml_metal_synchronize)
        [ctx_src->cmd_bufs_ext addObject:cmd_buf];
        ctx_src->cmd_buf_last = cmd_buf;

        [cmd_buf retain];

        // make future work on the destination context wait for the copy event
        ggml_metal_event_wait(ctx_dst, ev_cpy);

        return true;
    }
}
| 415 | |||
// Submit the compute graph `gf` to the GPU.
// The first n_nodes_0 nodes are encoded and committed by the calling thread;
// the remaining n_nodes_1 nodes are split across n_cb worker invocations of
// ctx->encode_async (run via dispatch_apply). Normally returns without
// waiting for GPU completion - see ggml_metal_synchronize().
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = MAX(64, 0.1*gf->n_nodes);

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // keep the memory wired
    ggml_metal_device_rsets_keep_alive(ctx->dev);

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates it's own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2

    @autoreleasepool {
        ctx->gf = gf;

        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        // ceil-divide the remaining nodes across the n_cb workers
        // NOTE(review): assumes ctx->n_cb >= 1 - a value of 0 would divide by zero; confirm set_n_cb is always called with n_cb >= 1
        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            // one-shot flag set by ggml_metal_capture_next_compute()
            ctx->capture_next_compute = false;

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                // create capture scope
                id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }

        // short-hand
        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            // replace (and release) the buffer from the previous graph compute
            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            // encode nodes [0, n_nodes_0) synchronously on this thread
            ctx->encode_async(n_cb);
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue];

                // update the pointer to the last queued command buffer
                // this is needed to implement synchronize()
                ctx->cmd_buf_last = cmd_buf;
            }
        }

        // run the n_cb encoders in parallel on the concurrent dispatch queue
        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // for debugging: block until graph is computed
        //[ctx->cmd_buf_last waitUntilCompleted];

        // enter here only when capturing in order to wait for all computation to finish
        // otherwise, we leave the graph to compute asynchronously
        // NOTE(review): the guard is `!use_capture`, so this block is skipped on
        // the compute that started the capture and entered on later computes
        // while capture_started is still set; capture_started is never cleared
        // anywhere in this file, so every subsequent compute re-enters this
        // block and calls endScope/stopCapture again - confirm this is intended
        if (!use_capture && ctx->capture_started) {
            // wait for completion and check status of each command buffer
            // needed to detect if the device ran out-of-memory for example (#1881)
            {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }
            }

            for (int i = 0; i < n_cb; ++i) {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }

                // buffers with cb_idx >= 2 are not enqueued above when an abort
                // callback is installed - poll the callback and commit them here
                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
                if (!next_buffer) {
                    continue;
                }

                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
                if (next_queued) {
                    continue;
                }

                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                    GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
                    return GGML_STATUS_ABORTED;
                }

                [next_buffer commit];
            }

            [ctx->capture_scope endScope];
            [[MTLCaptureManager sharedCaptureManager] stopCapture];
        }
    }

    return GGML_STATUS_SUCCESS;
}
| 579 | |||
// Run the generic ggml graph optimizer over `gf`, unless disabled for this
// context (GGML_METAL_GRAPH_OPTIMIZE_DISABLE).
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    //const int64_t t_start = ggml_time_us();

    if (!ctx->use_graph_optimize) {
        return;
    }

    ggml_graph_optimize(gf);

    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
}
| 589 | |||
// Record (signal) event `ev` at the current tail of this context's queue.
// The signal fires once all previously committed GPU work has executed.
void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandQueue> q = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cb = [q commandBuffer];

        ggml_metal_event_encode_signal(ev, cb);

        [cb commit];

        // keep the command buffer alive until the next ggml_metal_synchronize,
        // which releases everything in cmd_bufs_ext
        [cb retain];
        [ctx->cmd_bufs_ext addObject:cb];

        ctx->cmd_buf_last = cb;
    }
}
| 605 | |||
// Make all subsequently committed work on this context's queue wait until
// event `ev` has been signaled.
void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
    @autoreleasepool {
        id<MTLCommandQueue> q = ggml_metal_device_get_queue(ctx->dev);
        id<MTLCommandBuffer> cb = [q commandBuffer];

        ggml_metal_event_encode_wait(ev, cb);

        [cb commit];

        // keep the command buffer alive until the next ggml_metal_synchronize,
        // which releases everything in cmd_bufs_ext
        [cb retain];
        [ctx->cmd_bufs_ext addObject:cb];

        ctx->cmd_buf_last = cb;
    }
}
| 621 | |||
// Accessor for the event used to order async copies between Metal contexts.
ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
    return ctx->ev_cpy;
}
| 625 | |||
// Configure the number of worker threads (and command buffers) used to encode
// a graph, and (re)build the encode_async block that each worker runs.
// Worker cb_idx in [0, n_cb) encodes a slice of the trailing n_nodes_1 nodes;
// cb_idx == n_cb (the main thread) encodes the leading n_nodes_0 nodes.
void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
    if (ctx->n_cb != n_cb) {
        // clamp to the size of the cmd_bufs array
        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);

        if (ctx->n_cb > 2) {
            // note: logs the requested (unclamped) value
            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
        }
    }

    // drop the previously installed block, if any
    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    ctx->encode_async = Block_copy(^(size_t iter) {
        const int cb_idx = iter;
        // read n_cb through ctx so the block sees the current value
        const int n_cb_l = ctx->n_cb;

        const int n_nodes_0 = ctx->n_nodes_0;
        const int n_nodes_1 = ctx->n_nodes_1;

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

        // default range: the main thread (cb_idx == n_cb_l) encodes [0, n_nodes_0)
        int idx_start = 0;
        int idx_end = n_nodes_0;

        if (cb_idx < n_cb_l) {
            // worker slice of the remaining nodes; the last worker takes
            // whatever is left so the slices cover exactly n_nodes_1 nodes
            idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
            idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
        }

        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;

        // NOTE(review): capture_next_compute is cleared by graph_compute before
        // the encoders run, so this argument appears to always be false when a
        // capture was requested - confirm intended
        ggml_metal_op_t ctx_op = ggml_metal_op_init(
            ctx->dev,
            cmd_buf,
            ctx->gf,
            idx_start,
            idx_end,
            ctx->use_fusion,
            ctx->use_concurrency,
            ctx->capture_next_compute,
            ctx->debug_graph,
            ctx->debug_fusion);

        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
            // encode returns the number of nodes consumed (> 1 when ops are fused),
            // or 0 on failure
            const int res = ggml_metal_op_encode(ctx_op, idx);
            if (res == 0) {
                break;
            }

            idx += res - 1;
        }

        ggml_metal_op_free(ctx_op);

        // buffers >= 2 are committed later by graph_compute when an abort
        // callback is installed (so the computation can still be aborted)
        if (cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
    });
}
| 686 | |||
// Install a callback polled during graph computation; when it returns true,
// command buffers that have not yet been enqueued are not committed and
// ggml_metal_graph_compute returns GGML_STATUS_ABORTED.
void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
    ctx->abort_callback = abort_callback;
    ctx->abort_callback_data = user_data;
}
| 691 | |||
// Check whether the underlying device supports the given Apple GPU family
// (family == 1 maps to MTLGPUFamilyApple1).
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
    GGML_ASSERT(ctx->dev != nil);

    id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);

    const MTLGPUFamily target = (MTLGPUFamily) (MTLGPUFamilyApple1 + family - 1);

    return [device supportsFamily:target];
}
| 699 | |||
// Request a Metal GPU trace capture of the next graph computation
// (the trace is written to /tmp/perf-metal.gputrace).
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
    ctx->capture_next_compute = true;
}
