1
0
mirror of https://github.com/ggerganov/llama.cpp.git synced 2025-04-20 21:46:07 +00:00

cuda : enable CUDA Graph on CUDA Toolkit < 12.x ()

* Enable CUDA Graph on CTK < 12.x

`cudaGraphExecUpdate` API was changed on 12.x. For this reason CUDA graph support was disabled on older CUDA toolkit. This change enables CUDA support in CTK version < 12.x by using older API if CTK < 12.x.

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA
This commit is contained in:
Gaurav Garg 2025-03-17 23:55:13 +05:30 committed by GitHub
parent 01e8f2138b
commit b1b132efcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 11 additions and 12 deletions
ggml/src

@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
};
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH
#endif

@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
#ifdef __HIP_PLATFORM_AMD__
hipGraphNode_t errorNode;
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#endif
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);

@ -112,7 +112,7 @@
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate

@ -119,7 +119,7 @@
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@ -132,6 +132,7 @@
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;

@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
add_compile_definitions(GGML_USE_MUSA)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_GRAPHS)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()