The newly added test class is failing and blocking presubmits

Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53

PiperOrigin-RevId: 654842341
Vladimir Belitskiy 2024-07-22 11:51:40 -07:00 committed by jax authors
parent 0d7531b4f1
commit d7b821b04d
5 changed files with 64 additions and 105 deletions

@@ -879,7 +879,6 @@ def _splash_attention_forward(
save_residuals: bool,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False
) -> SplashCustomReturnType:
num_q_heads, q_seq_len, head_dim = q.shape
bq, bkv = block_sizes.block_q, block_sizes.block_kv
@@ -1111,7 +1110,6 @@ def _splash_attention_forward(
compiler_params=dict(mosaic=mosaic_params),
out_shape=out_shapes,
name=kernel_name,
interpret=interpret,
)(
fwd_mask_info.data_next,
fwd_mask_info.block_mask,
@@ -1148,7 +1146,7 @@ def _splash_attention_forward(
return out
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
def _splash_attention_custom(
fwd_mask_info: mask_info_lib.MaskInfo,
dq_mask_info: mask_info_lib.MaskInfo | None,
@@ -1164,7 +1162,6 @@ def _splash_attention_custom(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False,
) -> SplashCustomReturnType:
# The forward function does not use the dq and dkv MaskInfos, it just forwards
# them to the backward function as residuals. This is a way to communicate
@@ -1186,11 +1183,10 @@ def _splash_attention_custom(
mask_value=mask_value,
is_mqa=is_mqa,
block_sizes=block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
)
@@ -1209,7 +1205,6 @@ def _splash_attention_fwd(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False,
) -> tuple[
tuple[jax.Array],
SplashResidualsType,
@@ -1217,7 +1212,7 @@ def _splash_attention_fwd(
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")
out, (logsumexp,) = _splash_attention_forward( # pytype: disable=wrong-arg-types
out, (logsumexp,) = _splash_attention_forward(
fwd_mask_info,
q,
k,
@@ -1226,11 +1221,10 @@ def _splash_attention_fwd(
mask_value=mask_value,
is_mqa=is_mqa,
block_sizes=block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=True,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
)
return out, (
q,
@@ -1355,7 +1349,6 @@ def _splash_attention_bwd_dq(
k_layout: QKVLayout,
v_layout: QKVLayout,
mask_function: MaskFunctionType | None,
interpret: bool,
):
num_q_heads, q_seq_len, head_dim = q.shape
if is_mqa:
@@ -1567,7 +1560,6 @@ def _splash_attention_bwd_dq(
out_shape=out_shapes,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
interpret=interpret,
)(
mask_info.data_next,
mask_info.block_mask,
@@ -1797,7 +1789,6 @@ def _splash_attention_bwd_dkv(
k_layout: QKVLayout,
v_layout: QKVLayout,
mask_function: MaskFunctionType | None,
interpret: bool,
):
num_q_heads, q_seq_len, head_dim = q.shape
if is_mqa:
@@ -2116,7 +2107,6 @@ def _splash_attention_bwd_dkv(
out_shape=out_shapes,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
interpret=interpret,
)(
mask_info.data_next,
mask_info.block_mask,
@@ -2149,7 +2139,6 @@ def _splash_attention_bwd(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None,
interpret: bool,
res: SplashResidualsType,
do: jax.Array,
) -> tuple[
@@ -2204,7 +2193,6 @@ def _splash_attention_bwd(
k_layout=block_sizes.k_layout,
v_layout=block_sizes.v_layout,
mask_function=mask_function,
interpret=interpret,
)
if not use_fused_bwd_kernel:
assert dq is None
@@ -2226,7 +2214,6 @@ def _splash_attention_bwd(
k_layout=block_sizes.k_layout,
v_layout=block_sizes.v_layout,
mask_function=mask_function,
interpret=interpret,
)
# Match the signature of the fwd function.
assert dq is not None
@@ -2254,7 +2241,6 @@ _splash_attention_custom.defvjp(_splash_attention_fwd, _splash_attention_bwd)
"attn_logits_soft_cap",
"residual_checkpoint_name",
"mask_function",
"interpret",
],
)
def _splash_attention(
@@ -2273,7 +2259,6 @@ def _splash_attention(
attn_logits_soft_cap: float | None,
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
interpret: bool,
) -> SplashCustomReturnType:
return _splash_attention_custom(
fwd_mask_info,
@@ -2290,7 +2275,6 @@ def _splash_attention(
attn_logits_soft_cap=attn_logits_soft_cap,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
interpret=interpret,
)
@@ -2394,7 +2378,6 @@ def _make_splash_attention(
head_shards: int,
q_seq_shards: int,
residual_checkpoint_name: str | None = None,
interpret: bool = False,
):
if len(mask.shape) != 3:
raise ValueError(f'Unexpected mask shape: {mask.shape}')
@@ -2456,7 +2439,6 @@ def _make_splash_attention(
attn_logits_soft_cap=attn_logits_soft_cap,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function_fwd,
interpret=interpret,
)
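
The hunks above strip an `interpret` keyword that the reverted change had threaded from the splash-attention entry points down into `pl.pallas_call`, registering it along the way as a `custom_vjp` non-differentiable argument (the `nondiff_argnums` tuple shrinks from `(7, ..., 14)` back to `(7, ..., 13)`) and as a `jax.jit` static argument name. A minimal sketch of that general pattern, with hypothetical names (`add_one`, `add_one_kernel`) rather than the actual splash-attention code:

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
  # Element-wise kernel: write x + 1 into the output block.
  o_ref[...] = x_ref[...] + 1.0


@functools.partial(jax.jit, static_argnames=["interpret"])
def add_one(x: jax.Array, *, interpret: bool = False) -> jax.Array:
  # interpret=True runs the kernel through the Pallas interpreter (usable on
  # CPU); interpret=False compiles it for the accelerator backend.
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=interpret,
  )(x)


x = jnp.zeros((8,), jnp.float32)
print(add_one(x, interpret=True))  # [1. 1. 1. 1. 1. 1. 1. 1.]
```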

@@ -378,9 +378,10 @@ jax_test(
"tpu_splash_attention_kernel_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
shard_count = 24,
shard_count = 18,
tags = [
"noasan", # Times out.
"nomsan", # Times out.
@@ -397,6 +398,7 @@ jax_test(
"tpu_splash_attention_mask_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
deps = [
@@ -415,6 +417,7 @@ jax_test(
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
@@ -448,6 +451,7 @@ jax_test(
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [

@@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
from absl.testing import absltest
from absl.testing import parameterized
@@ -35,26 +34,14 @@ config.parse_flags_with_absl()
@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
class DecodeAttentionTest(jtu.JaxTestCase):
def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(["cpu", "gpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU and GPU the test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32" and not self.INTERPRET:
self.skipTest("Only works on non-Windows platforms")
if not jtu.is_cuda_compute_capability_at_least("8.0"):
self.skipTest("Fused attention only works on GPUs with capability >= sm80")
super().setUp()
class DecodeAttentionTest(PallasBaseTest):
INTERPRET = False
@parameterized.named_parameters(*[
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}",
@@ -92,7 +79,7 @@ class DecodeAttentionTest(PallasBaseTest):
k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)
o = decode_attention.mqa(q, k, v, interpret=self.INTERPRET)
o = decode_attention.mqa(q, k, v)
o_ref = decode_attention.mqa_reference(q, k, v)
np.testing.assert_allclose(o, o_ref, atol=0.05)
@@ -142,13 +129,10 @@ class DecodeAttentionTest(PallasBaseTest):
k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
)
o = decode_attention.gqa(q, k, v, interpret=self.INTERPRET)
o = decode_attention.gqa(q, k, v)
o_ref = decode_attention.gqa_reference(q, k, v)
np.testing.assert_allclose(o, o_ref, atol=0.05)
class DecodeAttentionInterpreterTest(DecodeAttentionTest):
INTERPRET = True
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
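
The deleted lines above are the interpret-mode scaffolding the revert removes from this test file: a `PallasBaseTest` base class whose `INTERPRET` flag gates the CPU skip and is forwarded to the kernels under test (`decode_attention.mqa(..., interpret=self.INTERPRET)`), plus an `*InterpreterTest` subclass that reruns the same tests with `INTERPRET = True`. A minimal self-contained sketch of that pattern, using a trivial hypothetical `copy_kernel` in place of the decode-attention kernels:

```python
import jax
import jax.numpy as jnp
from absl.testing import absltest
from jax.experimental import pallas as pl
import jax._src.test_util as jtu


def copy_kernel(x_ref, o_ref):
  # Copy the input block to the output block.
  o_ref[...] = x_ref[...]


class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False  # Subclasses flip this to rerun the same tests in interpret mode.

  def setUp(self):
    if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
      self.skipTest("On CPU the test works only in interpret mode")
    super().setUp()


class CopyTest(PallasBaseTest):

  def test_copy(self):
    x = jnp.arange(8, dtype=jnp.float32)
    out = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=self.INTERPRET,  # The flag the reverted change threaded through.
    )(x)
    self.assertArraysEqual(out, x)


class CopyInterpreterTest(CopyTest):
  INTERPRET = True  # Same test methods, run through the Pallas interpreter.


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
```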

@@ -114,14 +114,14 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):
@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
class PallasTest(jtu.JaxTestCase):
INTERPRET = False
def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(["cpu", "gpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU and GPU the test works only in 32-bit")
if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled:
self.skipTest("On GPU the test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
@@ -135,24 +135,28 @@ class PallasBaseTest(jtu.JaxTestCase):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
class FusedAttentionTest(PallasBaseTest):
class FusedAttentionTest(PallasTest):
def setUp(self):
super().setUp()
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not intended for TPU")
if jtu.test_device_matches(["cpu", "tpu"]):
self.skipTest("Works only on GPU")
@jtu.parameterized_filterable(
kwargs=[
dict(
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_heads,
head_dim=head_dim,
causal=causal,
use_fwd=use_fwd,
use_segment_ids=use_segment_ids,
kwargs=kwargs,
@parameterized.named_parameters(
*[
(
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}"
f"_{use_fwd=}_{use_segment_ids=}_{kwargs=}"
),
batch_size,
seq_len,
num_heads,
head_dim,
causal,
use_fwd,
use_segment_ids,
kwargs,
)
for (
batch_size,
@@ -179,7 +183,6 @@ class FusedAttentionTest(PallasBaseTest):
)
def test_fused_attention_fwd(
self,
*,
batch_size,
seq_len,
num_heads,
@@ -212,8 +215,7 @@ class FusedAttentionTest(PallasBaseTest):
def impl(q, k, v):
v, _ = jax.vjp(
functools.partial(
attention.mha, causal=causal, segment_ids=segment_ids,
interpret=self.INTERPRET, **kwargs
attention.mha, causal=causal, segment_ids=segment_ids, **kwargs
),
q,
k,
@@ -223,22 +225,25 @@ class FusedAttentionTest(PallasBaseTest):
else:
impl = functools.partial(
attention.mha, causal=causal, segment_ids=segment_ids,
interpret=self.INTERPRET, **kwargs
attention.mha, causal=causal, segment_ids=segment_ids, **kwargs
)
o = impl(q, k, v)
o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal)
np.testing.assert_allclose(o, o_ref, atol=0.05)
@jtu.parameterized_filterable(
kwargs=[
dict(
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_heads,
head_dim=head_dim,
causal=causal,
use_segment_ids=use_segment_ids,
@parameterized.named_parameters(
*[
(
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_"
f"{use_segment_ids=}"
),
batch_size,
seq_len,
num_heads,
head_dim,
causal,
use_segment_ids,
)
for (
batch_size,
@@ -258,7 +263,7 @@ class FusedAttentionTest(PallasBaseTest):
]
)
def test_fused_attention_bwd(
self, *, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
):
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
@@ -278,8 +283,7 @@ class FusedAttentionTest(PallasBaseTest):
segment_ids = None
def f(q, k, v):
return attention.mha(q, k, v, segment_ids, causal=causal,
interpret=self.INTERPRET).sum()
return attention.mha(q, k, v, segment_ids, causal=causal).sum()
def f_ref(q, k, v):
return attention.mha_reference(q, k, v, segment_ids, causal=causal).sum()
@@ -296,7 +300,7 @@ class FusedAttentionInterpreterTest(FusedAttentionTest):
INTERPRET = True
class FusedLayerNormTest(PallasBaseTest):
class FusedLayerNormTest(PallasTest):
def setUp(self):
super().setUp()
@@ -344,7 +348,7 @@ class FusedLayerNormInterpreterTest(FusedLayerNormTest):
INTERPRET = True
class RmsNormTest(PallasBaseTest):
class RmsNormTest(PallasTest):
def setUp(self):
super().setUp()
@@ -392,7 +396,7 @@ class RmsNormInterpreterTest(RmsNormTest):
INTERPRET = True
class SoftmaxTest(PallasBaseTest):
class SoftmaxTest(PallasTest):
def setUp(self):
super().setUp()
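
Several hunks above in this test file also swap the parameterization back from `jtu.parameterized_filterable` (cases given as dicts, consumed as keyword-only arguments, hence the removed `*` in the test signatures) to `parameterized.named_parameters` (cases given as name-prefixed tuples, consumed positionally). A minimal sketch contrasting the two styles, assuming the usual `jtu` alias for `jax._src.test_util` and that `parameterized_filterable` accepts a `kwargs` list of dicts as the pre-revert code used it; `StyleDemoTest` and its parameters are hypothetical:

```python
import jax._src.test_util as jtu
from absl.testing import absltest
from absl.testing import parameterized


class StyleDemoTest(jtu.JaxTestCase):

  # Pre-revert style: each case is a dict of keyword arguments; the test
  # method declares them as keyword-only parameters.
  @jtu.parameterized_filterable(
      kwargs=[
          dict(seq_len=seq_len, causal=causal)
          for seq_len in (128, 256)
          for causal in (False, True)
      ]
  )
  def test_filterable_style(self, *, seq_len, causal):
    self.assertGreater(seq_len, 0)
    self.assertIsInstance(causal, bool)

  # Post-revert style: each case is a tuple whose first element names the case.
  @parameterized.named_parameters(*[
      (f"{seq_len=}_{causal=}", seq_len, causal)
      for seq_len in (128, 256)
      for causal in (False, True)
  ])
  def test_named_parameters_style(self, seq_len, causal):
    self.assertGreater(seq_len, 0)
    self.assertIsInstance(causal, bool)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
```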

@@ -290,19 +290,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]:
return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0))
@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
class AttentionTest(jtu.JaxTestCase):
def setUp(self):
if not self.INTERPRET:
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only interpret mode supported on non-TPU")
# TODO(b/327487669): selectively re-enable tests that works on TPU v3.
if not jtu.is_device_tpu_at_least(4):
self.skipTest("Not supported on TPU generations <= 3")
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU the test works only in 32-bit")
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Need TPU devices")
# TODO(b/327487669): selectively re-enable tests that works on TPU v3.
if not jtu.is_device_tpu_at_least(4):
self.skipTest("Not supported on TPU generations <= 3")
super().setUp()
@@ -316,7 +311,7 @@ class PallasBaseTest(jtu.JaxTestCase):
np.testing.assert_allclose(x, y, **kwargs)
class SplashAttentionTest(PallasBaseTest):
class SplashAttentionTest(AttentionTest):
@parameterized.product(
is_mqa=(False, True),
@@ -360,7 +355,6 @@ class SplashAttentionTest(PallasBaseTest):
mask,
block_sizes=block_sizes,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask)
@@ -368,7 +362,6 @@ class SplashAttentionTest(PallasBaseTest):
mask,
block_sizes=block_sizes,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
o = attn(q, k, v, segment_ids)
o_ref = attn_ref(
@@ -423,7 +416,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
save_residuals=True,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask)
@@ -432,7 +424,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
save_residuals=True,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
attn_ref = partial(
attn_ref,
@@ -573,7 +564,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
downcast_smem_data=downcast_smem_data,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask, backward_impl="custom")
@@ -582,7 +572,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
downcast_smem_data=downcast_smem_data,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids)
q32, k32, v32 = jax.tree.map(
@@ -650,9 +639,5 @@ class SplashAttentionTest(PallasBaseTest):
self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2)
class SplashAttentionInterpreterTest(SplashAttentionTest):
INTERPRET = True
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())