diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f48c64605..238e5d86b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -471,17 +471,55 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COUNT }; -// TODO: use MTLHeapTypePlacement and reset offset after every node struct ggml_metal_heap { int n; int fail; size_t need; + id device; id obj; id bufs[GGML_METAL_MAX_HEAP_BUFFERS]; }; +static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { + struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); + + MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; + desc.storageMode = MTLStorageModePrivate; + desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; + desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement + desc.size = size; + + heap->device = device; + heap->obj = [device newHeapWithDescriptor:desc]; + if (!heap->obj) { + GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); + + free(heap); + + return false; + } + + for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) { + heap->bufs[i] = nil; + } + + [desc release]; + + return heap; +} + +static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { + if (heap == nil) { + return; + } + + [heap->obj release]; + + free(heap); +} + static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { heap->n = 0; heap->fail = 0; @@ -498,6 +536,33 @@ static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { } } +static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) { + if (heap == nil) { + return false; + } + + [heap->obj release]; + + MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; + desc.storageMode = MTLStorageModePrivate; + desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; + desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement + desc.size = size; + + heap->obj = [heap->device newHeapWithDescriptor:desc]; + if (!heap->obj) { + GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); + + return false; + } + + [desc release]; + + ggml_metal_heap_reset(heap); + + return true; +} + static id ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) { const size_t size_aligned = GGML_PAD(size, alignment); @@ -531,7 +596,7 @@ struct ggml_backend_metal_context { id queue; // TODO: create heap per command buffer - struct ggml_metal_heap heap; + struct ggml_metal_heap * heap; dispatch_queue_t d_queue; @@ -757,24 +822,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - // allocate tmp heap with fixed size for testing - // TODO: factor into a function - { - MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; - desc.storageMode = MTLStorageModePrivate; - desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; - desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement - desc.size = 1024*1024; - - ctx->heap.n = 0; - - ctx->heap.obj = [device newHeapWithDescriptor:desc]; - for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) { - ctx->heap.bufs[i] = nil; - } - - [desc release]; - } + ctx->heap = ggml_metal_heap_init(device, 1024*1024); // load library if (ctx_dev->mtl_library == nil) { @@ -1218,8 +1266,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { Block_release(ctx->encode_async); - [ctx->queue release]; - [ctx->heap.obj release]; + [ctx->queue release]; + + ggml_metal_heap_free(ctx->heap); dispatch_release(ctx->d_queue); @@ -2217,12 +2266,6 @@ static void ggml_metal_encode_node( /*.nb3 =*/ nb03, }; - //id id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate]; - - //// save a reference to the heap-allocated buffer - //// TODO: simplify and check for available resources - //heap->bufs[heap->n++] = id_src0h; - id id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 32); if (!id_src0h) { break; @@ -4717,7 +4760,7 @@ static enum ggml_status ggml_metal_graph_compute( [command_buffer waitUntilCompleted]; // TODO: per command buffer heap - ggml_metal_heap_reset(&ctx->heap); + ggml_metal_heap_reset(ctx->heap); MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { @@ -5134,7 +5177,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - ggml_metal_encode_node(backend, idx, encoder, &ctx->heap); + ggml_metal_encode_node(backend, idx, encoder, ctx->heap); if (should_capture) { [encoder popDebugGroup]; @@ -5143,28 +5186,18 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { [encoder endEncoding]; - if (ctx->heap.fail == 0) { + if (ctx->heap->fail == 0) { break; } - // increase heap size - [ctx->heap.obj release]; + const size_t need = ctx->heap->need; - { - MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; - desc.storageMode = MTLStorageModePrivate; - desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; - desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement - desc.size = ctx->heap.need; + GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need); - GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, ctx->heap.need); - - ctx->heap.obj = [ctx->device newHeapWithDescriptor:desc]; - - [desc release]; + if (!ggml_metal_heap_resize(ctx->heap, need)) { + GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need); + break; } - - ggml_metal_heap_reset(&ctx->heap); } if (cb_idx < 2 || ctx->abort_callback == NULL) {