vulkan: Use unclamped loads for flash attention mask (#12720)

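nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple
of the number of rows in the matrix. The KV dim is a multiple of the number of
columns for the aligned shader. Mask tile loads in the aligned shader therefore
never cross the edge of the mask tensor in either dimension, so they do not
need to be clamped.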
This commit is contained in:
Jeff Bolz 2025-04-06 03:47:13 -05:00 committed by GitHub
parent 6bf28f0111
commit 80b717d493
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 1 deletions

View File

@ -1833,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
// can't use 256 for D==80.
uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
// mask dim1 is padded to 64; we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
};
@ -5511,6 +5513,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
// mask dim1 is padded to 64; we rely on this to avoid clamping mask loads
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
vk_pipeline pipeline = pipelines[aligned];
assert(pipeline);

View File

@ -256,7 +256,7 @@ void main() {
}
if (p.mask != 0) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
// When using grouped query attention, all rows use the same mask.
if (p.gqa_ratio > 1) {