Merge pull request #22882 from wenscarl:attn_layout_fix

PiperOrigin-RevId: 668636119
This commit is contained in:
jax authors 2024-08-28 15:44:23 -07:00
commit 2785a08ca9
2 changed files with 18 additions and 9 deletions

View File

@@ -284,8 +284,8 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
     raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
 def check_is_flash_attention(
-    query, key, layout, cudnn_version, has_bias, is_training):
-  if layout == AttentionLayout.BNTH:
+    query, key, layout: int, cudnn_version, has_bias, is_training):
+  if layout == AttentionLayout.BNTH.value:
     _, _, T, H = query.shape
     _, _, S, _ = key.shape
   else:

View File

@@ -429,18 +429,27 @@ class DotProductAttentionTest(jtu.JaxTestCase):
   def test_sdpa_utils(self):
     test_cases = [
-      (1, 257, 64, 8905, False, True),
-      (1, 1024, 64, 8905, False, False),
-      (1024, 1024, 64, 8905, False, False),
-      (1024, 1024, 128, 8905, False, False),
+      (1, 257, 64, 8905, False, True, True),
+      (1, 1024, 64, 8905, False, False, True),
+      (1024, 1024, 64, 8905, False, False, True),
+      (1024, 1024, 128, 8905, False, False, True),
+      (1024, 1024, 127, 8905, False, False, False),
     ]
     for k in test_cases:
-      sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
+      sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \
+        expected_pass = k
       query = jnp.empty((4, sql_q, 4, head_dim))
       key = jnp.empty((4, sql_v, 4, head_dim))
-      check_is_flash_attention(
-        query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training)
+      if expected_pass:
+        check_is_flash_attention(
+            query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
+            is_training)
+      else:
+        with self.assertRaises(NotImplementedError):
+          check_is_flash_attention(
+              query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
+              is_training)
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())