diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index b2d857e1e..6bb5d0834 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); + int64_t antiquantGroupSize = 0; + if (src0->ne[0] > QK8_0) { + antiquantGroupSize = QK8_0; + } ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, - nullptr, nullptr, nullptr, QK8_0, acl_output_tensor, + nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor, &workspaceSize, &executor)); if (workspaceAddr == nullptr) { workspaceAddr = workspace_allocator.alloc(workspaceSize); @@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( acl_input_tensor, acl_weight_tensor, acl_scale_tensor, - nullptr, nullptr, nullptr, nullptr, QK8_0, + nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor, &workspaceSize, &executor)); ACL_CHECK(aclnnWeightQuantBatchMatmulV2( workspaceAddr, workspaceSize, executor, ctx.stream())); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b8d272cda..68cd9920d 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { case GGML_TYPE_Q8_0: - // Current groupsize should not be greater than k-1 in - // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize - if (op->src[0]->ne[0] <= QK8_0) { - return false; - } case GGML_TYPE_F16: case GGML_TYPE_F32: case GGML_TYPE_Q4_0: