Adding more tests for multi-head attention

Gleb Pobudzey 2024-12-09 20:49:06 +00:00
parent 66b900540a
commit e1e174fbc4


@@ -148,40 +148,16 @@ class FusedAttentionTest(PallasBaseTest):
     if jtu.test_device_matches(["tpu"]):
       self.skipTest("Not intended for TPU")
-  @jtu.parameterized_filterable(
-      kwargs=[
-          dict(
-              batch_size=batch_size,
-              seq_len=seq_len,
-              num_heads=num_heads,
-              head_dim=head_dim,
-              causal=causal,
-              use_fwd=use_fwd,
-              use_segment_ids=use_segment_ids,
-              kwargs=kwargs,
-          )
-          for (
-              batch_size,
-              seq_len,
-              num_heads,
-              head_dim,
-              causal,
-              use_fwd,
-              use_segment_ids,
-              kwargs,
-          ) in [
-              (1, 384, 1, 64, False, False, True, {}),
-              (1, 384, 1, 64, False, False, False, {}),
-              (2, 384, 2, 64, False, False, True, {}),
-              (1, 384, 1, 64, True, False, True, {}),
-              # (2, 384, 2, 64, True, False, True, {}),  # TODO(sharadmv): Investigate.
-              (1, 384, 8, 64, True, True, True, {}),
-              (1, 384, 8, 64, True, True, False, {}),
-              (2, 384, 8, 64, True, True, True, {}),
-              # regression test: https://github.com/jax-ml/jax/pull/17314
-              (1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}),
-          ]
-      ]
+  @jtu.sample_product(
+      batch_size=(1, 2),
+      seq_len=(128, 384),
+      num_heads=(1, 2, 8),
+      head_dim=(32, 64, 128),
+      block_q=(64, 128),
+      block_k=(64, 128),
+      causal=(True, False),
+      use_fwd=(True, False),
+      use_segment_ids=(True, False),
   )
   def test_fused_attention_fwd(
       self,
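For context, here is a minimal sketch (not part of this commit) of the expansion that the switch to jtu.sample_product implies, assuming the decorator takes the full Cartesian product of the keyword sequences: instead of maintaining an explicit list of parameter tuples, one test case is generated per combination. The expand_cases helper below is hypothetical and only illustrates the idea; the real decorator lives in JAX's test utilities.

# Hypothetical helper, for illustration only: expands keyword sequences into
# one parameter dict per combination, the way a Cartesian-product decorator would.
import itertools

def expand_cases(**param_sequences):
  names = list(param_sequences)
  for values in itertools.product(*param_sequences.values()):
    yield dict(zip(names, values))

# Under this assumption, the fwd decorator above would cover
# 2*2*3*3*2*2*2*2*2 = 1152 parameter combinations.
cases = list(expand_cases(batch_size=(1, 2), seq_len=(128, 384), causal=(True, False)))
assert len(cases) == 8  # 2 * 2 * 2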
@@ -190,10 +166,11 @@ class FusedAttentionTest(PallasBaseTest):
       seq_len,
       num_heads,
       head_dim,
+      block_q,
+      block_k,
       causal,
       use_fwd,
       use_segment_ids,
-      kwargs,
   ):
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
@@ -218,8 +195,12 @@ class FusedAttentionTest(PallasBaseTest):
       def impl(q, k, v):
         v, _ = jax.vjp(
             functools.partial(
-                attention.mha, causal=causal, segment_ids=segment_ids,
-                interpret=self.INTERPRET, **kwargs
+                attention.mha,
+                block_q=block_q,
+                block_k=block_k,
+                causal=causal,
+                segment_ids=segment_ids,
+                interpret=self.INTERPRET,
             ),
             q,
             k,
@@ -229,42 +210,34 @@ class FusedAttentionTest(PallasBaseTest):
     else:
       impl = functools.partial(
-          attention.mha, causal=causal, segment_ids=segment_ids,
-          interpret=self.INTERPRET, **kwargs
+          attention.mha,
+          block_q=block_q,
+          block_k=block_k,
+          causal=causal,
+          segment_ids=segment_ids,
+          interpret=self.INTERPRET,
       )
     o = impl(q, k, v)
     o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal)
     np.testing.assert_allclose(o, o_ref, atol=0.05)
-  @jtu.parameterized_filterable(
-      kwargs=[
-          dict(
-              batch_size=batch_size,
-              seq_len=seq_len,
-              num_heads=num_heads,
-              head_dim=head_dim,
-              causal=causal,
-              use_segment_ids=use_segment_ids,
-          )
-          for (
-              batch_size,
-              seq_len,
-              num_heads,
-              head_dim,
-              causal,
-              use_segment_ids,
-          ) in [
-              (1, 384, 1, 32, False, True),
-              (1, 384, 1, 32, False, False),
-              (2, 384, 2, 32, False, True),
-              (2, 384, 2, 32, False, False),
-              (1, 384, 1, 32, True, True),
-              (2, 384, 2, 32, True, True),
-          ]
-      ]
+  @jtu.sample_product(
+      batch_size=(1, 2),
+      seq_len=(128, 384),
+      num_heads=(1, 2, 4),
+      head_dim=(32,),
+      causal=(True, False),
+      use_segment_ids=(True, False),
   )
   def test_fused_attention_bwd(
-      self, *, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
+      self,
+      *,
+      batch_size,
+      seq_len,
+      num_heads,
+      head_dim,
+      causal,
+      use_segment_ids,
   ):
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
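Following the call pattern shown in the diff, here is a rough usage sketch of invoking attention.mha with explicit block sizes the way the updated test does. The shapes, dtype, and module path below are illustrative assumptions, not taken from the commit.

# Illustrative sketch only: mirrors the test's call pattern with explicit
# block sizes. Shapes, dtype, and the import path are assumptions.
import functools
import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import attention  # assumed import path

batch_size, seq_len, num_heads, head_dim = 1, 384, 2, 64
k1, k2, k3 = random.split(random.key(0), 3)
shape = (batch_size, seq_len, num_heads, head_dim)
q = random.normal(k1, shape, dtype=jnp.float16)
k = random.normal(k2, shape, dtype=jnp.float16)
v = random.normal(k3, shape, dtype=jnp.float16)

impl = functools.partial(
    attention.mha,
    block_q=64,
    block_k=64,
    causal=False,
    segment_ids=None,
    interpret=True,  # interpreter mode so the sketch does not require a GPU
)
o = impl(q, k, v)
o_ref = attention.mha_reference(q, k, v, None, causal=False)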