diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 238e5d86b..cb9523506 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -591,13 +591,16 @@ static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t
     return buf;
 }
 
+struct ggml_metal_command_buffer {
+    id<MTLCommandBuffer> obj;
+
+    struct ggml_metal_heap * heap;
+};
+
 struct ggml_backend_metal_context {
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
-    // TODO: create heap per command buffer
-    struct ggml_metal_heap * heap;
-
     dispatch_queue_t d_queue;
 
     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ struct ggml_backend_metal_context {
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    //id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-    ctx->heap = ggml_metal_heap_init(device, 1024*1024);
-
     // load library
     if (ctx_dev->mtl_library == nil) {
         ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@@ -877,7 +879,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        // create 1MB heaps per command buffer
+        // these can be resized during compute when necessary
+        ctx->cmd_bufs[i].heap = ggml_metal_heap_init(device, 1024*1024);
     }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
 
     [ctx->queue release];
 
-    ggml_metal_heap_free(ctx->heap);
+    //ggml_metal_heap_free(ctx->heap);
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        [ctx->cmd_bufs[i].obj release];
+        ggml_metal_heap_free(ctx->cmd_bufs[i].heap);
+    }
 
     dispatch_release(ctx->d_queue);
 
@@ -4712,25 +4722,25 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         // the main thread commits the first few commands immediately
-        // command_buffer[n_cb]
+        // cmd_buf[n_cb]
         {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[n_cb] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
             ctx->encode_async(n_cb);
         }
 
         // prepare the rest of the command buffers asynchronously
-        // command_buffer[0.. n_cb)
+        // cmd_buf[0.. n_cb)
         for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[cb_idx] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
             // always enqueue the first two command buffers
             // enqueue all of the command buffers if we don't need to abort
             if (cb_idx < 2 || ctx->abort_callback == NULL) {
-                [command_buffer enqueue];
+                [cmd_buf enqueue];
             }
         }
 
@@ -4739,16 +4749,16 @@ static enum ggml_status ggml_metal_graph_compute(
         // wait for completion and check status of each command buffer
         // needed to detect if the device ran out-of-memory for example (#1881)
         {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
 
-            // TODO: free main cb heap
+            ggml_metal_heap_reset(ctx->cmd_bufs[n_cb].heap);
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
 
                 return GGML_STATUS_FAILED;
@@ -4756,23 +4766,22 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         for (int i = 0; i < n_cb; ++i) {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
 
-            // TODO: per command buffer heap
-            ggml_metal_heap_reset(ctx->heap);
+            ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                }
 
                 return GGML_STATUS_FAILED;
             }
 
-            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
             if (!next_buffer) {
                 continue;
             }
@@ -5155,12 +5164,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        struct ggml_metal_heap * heap = ctx->cmd_bufs[cb_idx].heap;
 
         int n_try = 3;
 
        while (n_try-- > 0) {
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+            id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
             int node_start = 0;
             int node_end   = n_nodes_0;
@@ -5177,7 +5187,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                     [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
                 }
 
-                ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
+                ggml_metal_encode_node(backend, idx, encoder, heap);
 
                 if (should_capture) {
                     [encoder popDebugGroup];
@@ -5186,22 +5196,22 @@
             [encoder endEncoding];
 
-            if (ctx->heap->fail == 0) {
+            if (heap->fail == 0) {
                 break;
             }
 
-            const size_t need = ctx->heap->need;
+            const size_t need = heap->need;
 
             GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
 
-            if (!ggml_metal_heap_resize(ctx->heap, need)) {
+            if (!ggml_metal_heap_resize(heap, need)) {
                 GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
                 break;
            }
         }
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }
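
Note (illustration, not part of the patch): the core idea of the diff is to pair every in-flight command buffer with its own small heap for transient allocations, so concurrently encoded command buffers never alias each other's scratch memory, and each heap can be reset as soon as its own command buffer completes rather than waiting for all of them to drain. Below is a minimal standalone Objective-C sketch of that pattern using plain MTLHeap objects instead of ggml's ggml_metal_heap wrapper; N_CB, tmp, and the 4096-byte allocation are hypothetical names invented for the example.

// sketch.m: standalone illustration of the per-command-buffer heap pattern.
// Build: clang -fobjc-arc -framework Metal -framework Foundation sketch.m
#import <Metal/Metal.h>

#define N_CB 2 // number of command buffers in flight (hypothetical)

int main(void) {
    id<MTLDevice>       device = MTLCreateSystemDefaultDevice();
    id<MTLCommandQueue> queue  = [device newCommandQueue];

    // one 1MB heap per command buffer, mirroring struct ggml_metal_command_buffer
    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
    desc.storageMode = MTLStorageModePrivate;
    desc.size        = 1024*1024;

    id<MTLHeap> heaps[N_CB];
    for (int i = 0; i < N_CB; ++i) {
        heaps[i] = [device newHeapWithDescriptor:desc];
    }

    for (int i = 0; i < N_CB; ++i) {
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];

        // transient scratch buffer sub-allocated from this command buffer's
        // own heap; no other command buffer can alias it
        id<MTLBuffer> tmp = [heaps[i] newBufferWithLength:4096
                                                  options:MTLResourceStorageModePrivate];

        // ... encode compute work that reads/writes tmp ...

        [cmd_buf commit];
        [cmd_buf waitUntilCompleted];

        // once this command buffer has completed, its heap's space can be
        // recycled immediately (ggml_metal_heap_reset plays this role above)
        tmp = nil;
    }
    return 0;
}

This is also why the retry loop in ggml_backend_metal_set_n_cb can grow just the heap whose encoding failed (heap->fail) and re-encode that one command buffer, instead of resizing a single shared heap under contention from every encoder thread.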