mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix lint
This commit is contained in:
parent
ba6b1fdd09
commit
8c4d6d6903
@ -1789,7 +1789,7 @@ def dot_product_attention(
|
||||
f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
|
||||
)
|
||||
check_fp8_params(fp8_params)
|
||||
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
|
||||
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
|
||||
output, amax_s, amax_o = _dot_product_attention_fp8(
|
||||
query, key, value, fp8_params,
|
||||
scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
|
||||
|
Loading…
x
Reference in New Issue
Block a user