CUDA: Prefer vector flash decoding kernel for Gemma models (#12738)

* Prefer vector flash decoding kernel for Gemma models

Vector flash decoding kernel was not being picked for models with head dimension 256. Gemma models are in this category.
Removing this limit improves e2e performance by upto 12% in gen phase throughput for Gemm models.

* Update ggml/src/ggml-cuda/fattn.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Gaurav Garg 2025-04-03 21:50:29 +05:30 committed by GitHub
parent 5dd5d1ab00
commit c262beddf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);