diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 862b9b666..3189ae85d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1,6 +1,7 @@ #include "ggml-rpc.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cpp.h" #include #include @@ -853,12 +854,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); - ggml_free(ctx); return false; } @@ -871,7 +873,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); - ggml_free(ctx); return true; } @@ -985,11 +986,12 @@ bool rpc_server::set_tensor(const std::vector & input) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); @@ -1016,7 +1018,6 @@ bool rpc_server::set_tensor(const std::vector & input) { printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); - ggml_free(ctx); return true; } @@ -1060,11 +1061,12 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); @@ -1080,7 +1082,6 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set } ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); response.result = 1; - ggml_free(ctx); return true; } @@ -1090,11 +1091,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); - ggml_free(ctx); return false; } @@ -1110,11 +1112,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. // Currently unimplemented. GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); - ggml_free(ctx); return false; } - ggml_free(ctx); return true; } @@ -1124,11 +1124,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); @@ -1147,7 +1148,6 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< response.resize(request.size, 0); ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); - ggml_free(ctx); return true; } @@ -1157,12 +1157,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); - ggml_free(ctx); return false; } @@ -1180,7 +1182,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co dst_data + src_size, dst_base, dst_base + dst_buf_sz); - ggml_free(ctx); return false; } @@ -1188,7 +1189,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co __func__, (void*) src->buffer, (void*) dst->buffer); response.result = ggml_backend_buffer_copy_tensor(src, dst); - ggml_free(ctx); return true; } @@ -1242,7 +1242,9 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map tensor_ptrs; @@ -1257,7 +1259,6 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph } ggml_status status = ggml_backend_graph_compute(backend, graph); response.result = status; - ggml_free(ctx); return true; }