cont : heap for each cmd buffer [no ci]

This commit is contained in:
Georgi Gerganov 2025-04-10 14:56:47 +03:00
parent 9433c504c0
commit 2804db7812
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -591,13 +591,16 @@ static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t
return buf;
}
struct ggml_metal_command_buffer {
id<MTLCommandBuffer> obj;
struct ggml_metal_heap * heap;
};
struct ggml_backend_metal_context {
id<MTLDevice> device;
id<MTLCommandQueue> queue;
// TODO: create heap per command buffer
struct ggml_metal_heap * heap;
dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ struct ggml_backend_metal_context {
void (^encode_async)(size_t ith);
// n_cb command buffers + 1 used by the main thread
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
//id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
ctx->heap = ggml_metal_heap_init(device, 1024*1024);
// load library
if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@@ -877,7 +879,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->gf = nil;
ctx->encode_async = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->command_buffers[i] = nil;
ctx->cmd_bufs[i].obj = nil;
// create 1MB heaps per command buffer
// these can be resized during compute when necessary
ctx->cmd_bufs[i].heap = ggml_metal_heap_init(device, 1024*1024);
}
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->queue release];
ggml_metal_heap_free(ctx->heap);
//ggml_metal_heap_free(ctx->heap);
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
[ctx->cmd_bufs[i].obj release];
ggml_metal_heap_free(ctx->cmd_bufs[i].heap);
}
dispatch_release(ctx->d_queue);
@@ -4712,25 +4722,25 @@ static enum ggml_status ggml_metal_graph_compute(
}
// the main thread commits the first few commands immediately
// command_buffer[n_cb]
// cmd_buf[n_cb]
{
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[n_cb] = command_buffer;
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[n_cb].obj = cmd_buf;
[command_buffer enqueue];
[cmd_buf enqueue];
ctx->encode_async(n_cb);
}
// prepare the rest of the command buffers asynchronously
// command_buffer[0.. n_cb)
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[cb_idx] = command_buffer;
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
[cmd_buf enqueue];
}
}
@@ -4739,16 +4749,16 @@ static enum ggml_status ggml_metal_graph_compute(
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
[command_buffer waitUntilCompleted];
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
[cmd_buf waitUntilCompleted];
// TODO: free main cb heap
ggml_metal_heap_reset(ctx->cmd_bufs[n_cb].heap);
MTLCommandBufferStatus status = [command_buffer status];
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
@@ -4756,23 +4766,22 @@ static enum ggml_status ggml_metal_graph_compute(
}
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
[command_buffer waitUntilCompleted];
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
[cmd_buf waitUntilCompleted];
// TODO: per command buffer heap
ggml_metal_heap_reset(ctx->heap);
ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
MTLCommandBufferStatus status = [command_buffer status];
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
}
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
if (!next_buffer) {
continue;
}
@@ -5155,12 +5164,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
struct ggml_metal_heap * heap = ctx->cmd_bufs[cb_idx].heap;
int n_try = 3;
while (n_try-- > 0) {
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
int node_start = 0;
int node_end = n_nodes_0;
@@ -5177,7 +5187,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
ggml_metal_encode_node(backend, idx, encoder, heap);
if (should_capture) {
[encoder popDebugGroup];
@@ -5186,22 +5196,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder endEncoding];
if (ctx->heap->fail == 0) {
if (heap->fail == 0) {
break;
}
const size_t need = ctx->heap->need;
const size_t need = heap->need;
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
if (!ggml_metal_heap_resize(ctx->heap, need)) {
if (!ggml_metal_heap_resize(heap, need)) {
GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
break;
}
}
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
[cmd_buf commit];
}
});
}