commit 8c4d6d6903
Author: cjkkkk
Date:   2025-02-03 06:09:05 +00:00
Parent: ba6b1fdd09

@@ -1789,7 +1789,7 @@ def dot_product_attention(
         f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
     )
   check_fp8_params(fp8_params)
-  check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
+  check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
   output, amax_s, amax_o = _dot_product_attention_fp8(
       query, key, value, fp8_params,
       scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
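
For context, a minimal sketch of the kind of metadata the new arguments could carry, assuming (this is an assumption, not taken from the diff) that q_offsets / kv_offsets describe where each variable-length sequence starts in a packed batch, alongside the per-sequence lengths in q_seqlen / kv_seqlen that check_layout already validates. The helper name segment_offsets below is hypothetical and not part of the library.

# Hypothetical illustration only; segment_offsets is not a library API.
import numpy as np

def segment_offsets(seqlens):
    """Cumulative start offsets of packed segments (last entry = total length)."""
    seqlens = np.asarray(seqlens, dtype=np.int32)
    return np.concatenate([[0], np.cumsum(seqlens)]).astype(np.int32)

q_seqlen = [3, 5, 2]                 # token counts per sequence in a packed batch
q_offsets = segment_offsets(q_seqlen)
print(q_offsets)                     # [ 0  3  8 10]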