mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22882 from wenscarl:attn_layout_fix
PiperOrigin-RevId: 668636119
This commit is contained in:
commit
2785a08ca9
@@ -284,8 +284,8 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
|
||||
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
|
||||
|
||||
def check_is_flash_attention(
|
||||
- query, key, layout, cudnn_version, has_bias, is_training):
|
||||
- if layout == AttentionLayout.BNTH:
|
||||
+ query, key, layout: int, cudnn_version, has_bias, is_training):
|
||||
+ if layout == AttentionLayout.BNTH.value:
|
||||
_, _, T, H = query.shape
|
||||
_, _, S, _ = key.shape
|
||||
else:
|
||||
|
@@ -429,18 +429,27 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
def test_sdpa_utils(self):
|
||||
test_cases = [
|
||||
- (1, 257, 64, 8905, False, True),
|
||||
- (1, 1024, 64, 8905, False, False),
|
||||
- (1024, 1024, 64, 8905, False, False),
|
||||
- (1024, 1024, 128, 8905, False, False),
|
||||
+ (1, 257, 64, 8905, False, True, True),
|
||||
+ (1, 1024, 64, 8905, False, False, True),
|
||||
+ (1024, 1024, 64, 8905, False, False, True),
|
||||
+ (1024, 1024, 128, 8905, False, False, True),
|
||||
+ (1024, 1024, 127, 8905, False, False, False),
|
||||
]
|
||||
|
||||
for k in test_cases:
|
||||
- sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
|
||||
+ sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \
|
||||
+ expected_pass = k
|
||||
query = jnp.empty((4, sql_q, 4, head_dim))
|
||||
key = jnp.empty((4, sql_v, 4, head_dim))
|
||||
- check_is_flash_attention(
|
||||
- query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training)
|
||||
+ if expected_pass:
|
||||
+ check_is_flash_attention(
|
||||
+ query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
|
||||
+ is_training)
|
||||
+ else:
|
||||
+ with self.assertRaises(NotImplementedError):
|
||||
+ check_is_flash_attention(
|
||||
+ query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
|
||||
+ is_training)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user