mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-16 03:26:08 +00:00
Simplify and improve CUDA graphs through use of indirect copy pointers (#9017)
* CUDA: Simplify and improve CUDA graphs through use of indirect copy pointers Previously there was complexity in the CUDA graphs implementation due frequently changing parameters to copy kernels associated with K and V cache pointers. This patch simplifies by using indirection to avoid such parameters frequently changing, avoiding the need for frequent graph updates. Fixes #12152 * Addressed comments * fix HIP builds * properly sync to stream * removed ggml_cuda_cpy_fn_ptrs * move stream sync before free * guard to only use indirection with graphs * style fixes * check for errors --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
2a0dc97e56
commit
3f9da22c2b
@ -729,7 +729,13 @@ struct ggml_cuda_graph {
|
||||
bool disable_due_to_failed_graph_capture = false;
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
std::vector<char **> updated_kernel_arg;
|
||||
bool use_cpy_indirection = false;
|
||||
std::vector<char *> cpy_dest_ptrs;
|
||||
char ** dest_ptrs_d;
|
||||
int dest_ptrs_size = 0;
|
||||
// Index to allow each cpy kernel to be aware of it's position within the graph
|
||||
// relative to other cpy nodes.
|
||||
int graph_cpynode_index = -1;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -32,16 +32,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
@ -288,16 +290,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
@ -314,16 +318,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
@ -339,66 +345,84 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
// Copy destination pointers to GPU to be available when pointer indirection is in use
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
||||
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
if (cuda_graph->dest_ptrs_d != nullptr) {
|
||||
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
||||
}
|
||||
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
||||
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
||||
}
|
||||
// copy destination pointers to GPU
|
||||
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
|
||||
cuda_graph->graph_cpynode_index = 0; // reset index
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q8_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
@ -407,22 +431,22 @@ static void ggml_cpy_q4_0_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
@ -431,22 +455,22 @@ static void ggml_cpy_q4_1_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
@ -455,22 +479,22 @@ static void ggml_cpy_q5_0_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
@ -479,32 +503,32 @@ static void ggml_cpy_q5_1_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
@ -541,46 +565,60 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
char ** dest_ptrs_d = nullptr;
|
||||
int graph_cpynode_index = -1;
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection) {
|
||||
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
|
||||
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
||||
}
|
||||
#endif
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection) {
|
||||
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
@ -7,3 +7,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
||||
|
@ -2441,10 +2441,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
|
||||
bool use_cuda_graph) {
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
|
||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
@ -2476,8 +2477,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_CPY) {
|
||||
// store the copy op parameter which changes with each token.
|
||||
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
|
||||
|
||||
// Store the pointers which are updated for each token, such that these can be sent
|
||||
// to the device and accessed using indirection from CUDA graph
|
||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
|
||||
|
||||
// store a pointer to each copy op CUDA kernel to identify it later
|
||||
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
|
||||
if (!ptr) {
|
||||
@ -2485,10 +2489,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
|
||||
#endif
|
||||
} else {
|
||||
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
|
||||
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2497,6 +2497,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
}
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
cuda_ctx->cuda_graph->use_cpy_indirection = true;
|
||||
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
|
||||
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
|
||||
}
|
||||
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
@ -2551,51 +2557,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
return true;
|
||||
}
|
||||
|
||||
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
|
||||
|
||||
if (cuda_graph_update_required) {
|
||||
// Extract nodes from graph
|
||||
// First call with null argument gets number of nodes in graph
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
||||
// Subsequent call with non-null argument gets nodes
|
||||
cuda_ctx->cuda_graph->nodes.clear();
|
||||
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||
cuda_ctx->cuda_graph->params.clear();
|
||||
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||
if (cuda_ctx->cuda_graph->num_nodes > 0) {
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
|
||||
|
||||
// Loop over nodes, and extract kernel parameters from each node
|
||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
cudaGraphNodeType node_type;
|
||||
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
||||
if (node_type == cudaGraphNodeTypeKernel) {
|
||||
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
||||
if (stat == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
(void)cudaGetLastError();
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// One of the arguments to the copy kernel is updated for each token, hence we need to
|
||||
// replace that argument with the updated value in the CUDA graph
|
||||
// on update steps, the live parameters will already be captured
|
||||
int k = 0;
|
||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
|
||||
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
||||
*(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool cuda_graph_update_required = false;
|
||||
@ -2655,8 +2616,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
#endif
|
||||
|
||||
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
[[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
|
||||
bool & cuda_graph_update_required) {
|
||||
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
||||
|
||||
while (!graph_evaluated_or_captured) {
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
@ -2706,13 +2666,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
|
||||
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
||||
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
|
||||
|
||||
// Update graph executable
|
||||
update_cuda_graph_executable(cuda_ctx);
|
||||
|
||||
if (cuda_graph_update_required) { // Update graph executable
|
||||
update_cuda_graph_executable(cuda_ctx);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
#else
|
||||
@ -2726,10 +2682,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
|
||||
ggml_cuda_set_device(cuda_ctx->device);
|
||||
|
||||
// vector of pointers to CUDA cpy kernels, which are required to identify
|
||||
// kernel parameters which need updated in the graph for each token
|
||||
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
@ -2763,8 +2715,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
if (use_cuda_graph) {
|
||||
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
|
||||
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
|
||||
ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
|
||||
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
|
||||
|
||||
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
||||
if (use_cuda_graph && cuda_graph_update_required) {
|
||||
@ -2785,6 +2736,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
cuda_ctx->cuda_graph->use_cpy_indirection = false;
|
||||
}
|
||||
|
||||
#else
|
||||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
@ -2792,7 +2747,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
|
||||
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user