cont : refactor heap [no ci]

This commit is contained in:
Georgi Gerganov 2025-04-10 14:49:49 +03:00
parent 37450314b5
commit 9433c504c0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -471,17 +471,55 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COUNT
};
// TODO: use MTLHeapTypePlacement and reset offset after every node
struct ggml_metal_heap {
int n;
int fail;
size_t need;
id<MTLDevice> device;
id<MTLHeap> obj;
id<MTLBuffer> bufs[GGML_METAL_MAX_HEAP_BUFFERS];
};
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
desc.size = size;
heap->device = device;
heap->obj = [device newHeapWithDescriptor:desc];
if (!heap->obj) {
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
free(heap);
return false;
}
for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
heap->bufs[i] = nil;
}
[desc release];
return heap;
}
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
if (heap == nil) {
return;
}
[heap->obj release];
free(heap);
}
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
heap->n = 0;
heap->fail = 0;
@ -498,6 +536,33 @@ static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
}
}
static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
if (heap == nil) {
return false;
}
[heap->obj release];
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
desc.size = size;
heap->obj = [heap->device newHeapWithDescriptor:desc];
if (!heap->obj) {
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
return false;
}
[desc release];
ggml_metal_heap_reset(heap);
return true;
}
static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) {
const size_t size_aligned = GGML_PAD(size, alignment);
@ -531,7 +596,7 @@ struct ggml_backend_metal_context {
id<MTLCommandQueue> queue;
// TODO: create heap per command buffer
struct ggml_metal_heap heap;
struct ggml_metal_heap * heap;
dispatch_queue_t d_queue;
@ -757,24 +822,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
// allocate tmp heap with fixed size for testing
// TODO: factor into a function
{
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
desc.size = 1024*1024;
ctx->heap.n = 0;
ctx->heap.obj = [device newHeapWithDescriptor:desc];
for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
ctx->heap.bufs[i] = nil;
}
[desc release];
}
ctx->heap = ggml_metal_heap_init(device, 1024*1024);
// load library
if (ctx_dev->mtl_library == nil) {
@ -1219,7 +1267,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
Block_release(ctx->encode_async);
[ctx->queue release];
[ctx->heap.obj release];
ggml_metal_heap_free(ctx->heap);
dispatch_release(ctx->d_queue);
@ -2217,12 +2266,6 @@ static void ggml_metal_encode_node(
/*.nb3 =*/ nb03,
};
//id<MTLBuffer> id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
//// save a reference to the heap-allocated buffer
//// TODO: simplify and check for available resources
//heap->bufs[heap->n++] = id_src0h;
id<MTLBuffer> id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 32);
if (!id_src0h) {
break;
@ -4717,7 +4760,7 @@ static enum ggml_status ggml_metal_graph_compute(
[command_buffer waitUntilCompleted];
// TODO: per command buffer heap
ggml_metal_heap_reset(&ctx->heap);
ggml_metal_heap_reset(ctx->heap);
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
@ -5134,7 +5177,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(backend, idx, encoder, &ctx->heap);
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
if (should_capture) {
[encoder popDebugGroup];
@ -5143,28 +5186,18 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder endEncoding];
if (ctx->heap.fail == 0) {
if (ctx->heap->fail == 0) {
break;
}
// increase heap size
[ctx->heap.obj release];
const size_t need = ctx->heap->need;
{
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
desc.size = ctx->heap.need;
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, ctx->heap.need);
ctx->heap.obj = [ctx->device newHeapWithDescriptor:desc];
[desc release];
if (!ggml_metal_heap_resize(ctx->heap, need)) {
GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
break;
}
ggml_metal_heap_reset(&ctx->heap);
}
if (cb_idx < 2 || ctx->abort_callback == NULL) {