metal : add memory pool for temp allocs (wip) [no ci]

This commit is contained in:
Georgi Gerganov 2025-04-09 14:50:41 +03:00
parent 526739b879
commit c254b21307
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
// note: assumes single GPU device - the default one
// TODO: support multiple GPU devices
static struct ggml_backend_metal_device_context {
id<MTLDevice> mtl_device;
int mtl_device_ref_count;
id<MTLDevice> mtl_device;
int mtl_device_ref_count;
id<MTLLibrary> mtl_library;
bool has_simdgroup_reduction;
@ -470,6 +470,7 @@ enum ggml_metal_kernel_type {
struct ggml_backend_metal_context {
id<MTLCommandQueue> queue;
id<MTLHeap> heap;
dispatch_queue_t d_queue;
@ -693,6 +694,19 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
// allocate tmp heap with fixed size for testing
// TODO: figure out how to dynamically resize it
{
MTLHeapDescriptor *heapDescriptor = [[MTLHeapDescriptor alloc] init];
heapDescriptor.storageMode = MTLStorageModePrivate;
heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache;
heapDescriptor.size = 32*1024*1024;
ctx->heap = [device newHeapWithDescriptor:heapDescriptor];
[heapDescriptor release];
}
// load library
if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@ -1136,6 +1150,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
Block_release(ctx->encode_async);
[ctx->queue release];
[ctx->heap release];
dispatch_release(ctx->d_queue);
@ -1439,7 +1454,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
static void ggml_metal_encode_node(
ggml_backend_t backend,
int idx,
id<MTLComputeCommandEncoder> encoder) {
id<MTLComputeCommandEncoder> encoder,
id<MTLHeap> heap) {
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@ -2111,26 +2127,65 @@ static void ggml_metal_encode_node(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_metal_kargs_soft_max args = {
// cpy to tmp buffer in MTLHeap
ggml_metal_kargs_cpy args_cpy = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne00,
/*.ne1 =*/ ne01,
/*.ne2 =*/ ne02,
/*.ne3 =*/ ne03,
/*.nb0 =*/ nb00,
/*.nb1 =*/ nb01,
/*.nb2 =*/ nb02,
/*.nb3 =*/ nb03,
};
id<MTLBuffer> id_src0h = [heap newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
if (src0->type == GGML_TYPE_F16) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
}
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src0h offset:0 atIndex:2];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
// softmax
ggml_metal_kargs_soft_max args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.n_head_log2 =*/ n_head_log2,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src0h offset:0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src0h offset:0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
@ -4992,7 +5047,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(backend, idx, encoder);
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
if (should_capture) {
[encoder popDebugGroup];