mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Adding more tests for multi-head attention
This commit is contained in:
parent
66b900540a
commit
e1e174fbc4
@ -148,40 +148,16 @@ class FusedAttentionTest(PallasBaseTest):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Not intended for TPU")
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
causal=causal,
|
||||
use_fwd=use_fwd,
|
||||
use_segment_ids=use_segment_ids,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
for (
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
use_fwd,
|
||||
use_segment_ids,
|
||||
kwargs,
|
||||
) in [
|
||||
(1, 384, 1, 64, False, False, True, {}),
|
||||
(1, 384, 1, 64, False, False, False, {}),
|
||||
(2, 384, 2, 64, False, False, True, {}),
|
||||
(1, 384, 1, 64, True, False, True, {}),
|
||||
# (2, 384, 2, 64, True, False, True, {}), # TODO(sharadmv): Investigate.
|
||||
(1, 384, 8, 64, True, True, True, {}),
|
||||
(1, 384, 8, 64, True, True, False, {}),
|
||||
(2, 384, 8, 64, True, True, True, {}),
|
||||
# regression test: https://github.com/jax-ml/jax/pull/17314
|
||||
(1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}),
|
||||
]
|
||||
]
|
||||
@jtu.sample_product(
|
||||
batch_size=(1, 2),
|
||||
seq_len=(128, 384),
|
||||
num_heads=(1, 2, 8),
|
||||
head_dim=(32, 64, 128),
|
||||
block_q=(64, 128),
|
||||
block_k=(64, 128),
|
||||
causal=(True, False),
|
||||
use_fwd=(True, False),
|
||||
use_segment_ids=(True, False),
|
||||
)
|
||||
def test_fused_attention_fwd(
|
||||
self,
|
||||
@ -190,10 +166,11 @@ class FusedAttentionTest(PallasBaseTest):
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
block_q,
|
||||
block_k,
|
||||
causal,
|
||||
use_fwd,
|
||||
use_segment_ids,
|
||||
kwargs,
|
||||
):
|
||||
k1, k2, k3 = random.split(random.key(0), 3)
|
||||
q = random.normal(
|
||||
@ -218,8 +195,12 @@ class FusedAttentionTest(PallasBaseTest):
|
||||
def impl(q, k, v):
|
||||
v, _ = jax.vjp(
|
||||
functools.partial(
|
||||
attention.mha, causal=causal, segment_ids=segment_ids,
|
||||
interpret=self.INTERPRET, **kwargs
|
||||
attention.mha,
|
||||
block_q=block_q,
|
||||
block_k=block_k,
|
||||
causal=causal,
|
||||
segment_ids=segment_ids,
|
||||
interpret=self.INTERPRET,
|
||||
),
|
||||
q,
|
||||
k,
|
||||
@ -229,42 +210,34 @@ class FusedAttentionTest(PallasBaseTest):
|
||||
|
||||
else:
|
||||
impl = functools.partial(
|
||||
attention.mha, causal=causal, segment_ids=segment_ids,
|
||||
interpret=self.INTERPRET, **kwargs
|
||||
attention.mha,
|
||||
block_q=block_q,
|
||||
block_k=block_k,
|
||||
causal=causal,
|
||||
segment_ids=segment_ids,
|
||||
interpret=self.INTERPRET,
|
||||
)
|
||||
o = impl(q, k, v)
|
||||
o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal)
|
||||
np.testing.assert_allclose(o, o_ref, atol=0.05)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
causal=causal,
|
||||
use_segment_ids=use_segment_ids,
|
||||
)
|
||||
for (
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
use_segment_ids,
|
||||
) in [
|
||||
(1, 384, 1, 32, False, True),
|
||||
(1, 384, 1, 32, False, False),
|
||||
(2, 384, 2, 32, False, True),
|
||||
(2, 384, 2, 32, False, False),
|
||||
(1, 384, 1, 32, True, True),
|
||||
(2, 384, 2, 32, True, True),
|
||||
]
|
||||
]
|
||||
@jtu.sample_product(
|
||||
batch_size=(1, 2),
|
||||
seq_len=(128, 384),
|
||||
num_heads=(1, 2, 4),
|
||||
head_dim=(32,),
|
||||
causal=(True, False),
|
||||
use_segment_ids=(True, False),
|
||||
)
|
||||
def test_fused_attention_bwd(
|
||||
self, *, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
|
||||
self,
|
||||
*,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
use_segment_ids,
|
||||
):
|
||||
k1, k2, k3 = random.split(random.key(0), 3)
|
||||
q = random.normal(
|
||||
|
Loading…
x
Reference in New Issue
Block a user