Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
The newly added test class is failing and blocking presubmits.

Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53

PiperOrigin-RevId: 654842341
Parent: 0d7531b4f1
Commit: d7b821b04d
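Note: the reverted change threads an `interpret` flag from the public splash-attention entry points down to `pl.pallas_call`, so the kernels can run under the Pallas interpreter (for example on CPU) instead of being compiled with Mosaic. The snippet below is a minimal sketch of that pattern with a toy kernel; it is illustrative only and is not taken from the files in this diff.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
  # Toy element-wise kernel standing in for the attention kernels in the diff.
  o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y, *, interpret: bool = False):
  # The reverted change follows this shape: accept `interpret` at the API
  # boundary and forward it to pl.pallas_call alongside out_shape/name/etc.
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=interpret,
  )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones_like(x)
print(add(x, y, interpret=True))  # Runs via the interpreter, no TPU/GPU needed.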
@@ -879,7 +879,6 @@ def _splash_attention_forward(
save_residuals: bool,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False
) -> SplashCustomReturnType:
num_q_heads, q_seq_len, head_dim = q.shape
bq, bkv = block_sizes.block_q, block_sizes.block_kv
@@ -1111,7 +1110,6 @@ def _splash_attention_forward(
compiler_params=dict(mosaic=mosaic_params),
out_shape=out_shapes,
name=kernel_name,
interpret=interpret,
)(
fwd_mask_info.data_next,
fwd_mask_info.block_mask,
@@ -1148,7 +1146,7 @@ def _splash_attention_forward(
return out


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
def _splash_attention_custom(
fwd_mask_info: mask_info_lib.MaskInfo,
dq_mask_info: mask_info_lib.MaskInfo | None,
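Note: the `nondiff_argnums` change in the hunk above is the bookkeeping the revert undoes: `interpret` had been appended as an extra non-differentiable argument of the custom-VJP function, so its position had to be listed there. Below is a self-contained toy sketch of how `jax.custom_vjp` handles such static arguments; the function and names are illustrative and are not the attention kernel.

import jax
import jax.numpy as jnp
from functools import partial


# `use_log_scale` is a static flag, analogous to `interpret` above: it is
# listed in nondiff_argnums, so custom_vjp passes it straight through to the
# fwd/bwd rules instead of differentiating with respect to it.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, use_log_scale):
  scale = jnp.log(2.0) if use_log_scale else 2.0
  return scale * x * x


def scaled_square_fwd(x, use_log_scale):
  # fwd sees the full argument list and saves residuals for bwd.
  return scaled_square(x, use_log_scale), (x,)


def scaled_square_bwd(use_log_scale, res, g):
  # bwd receives the nondiff arguments first, then residuals and the cotangent.
  (x,) = res
  scale = jnp.log(2.0) if use_log_scale else 2.0
  return (2.0 * scale * x * g,)


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

print(jax.grad(scaled_square)(3.0, True))  # d/dx [log(2) * x^2] at x = 3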
@@ -1164,7 +1162,6 @@ def _splash_attention_custom(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False,
) -> SplashCustomReturnType:
# The forward function does not use the dq and dkv MaskInfos, it just forwards
# them to the backward function as residuals. This is a way to communicate
@@ -1186,11 +1183,10 @@ def _splash_attention_custom(
mask_value=mask_value,
is_mqa=is_mqa,
block_sizes=block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=save_residuals,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
)
@@ -1209,7 +1205,6 @@ def _splash_attention_fwd(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False,
) -> tuple[
tuple[jax.Array],
SplashResidualsType,
@@ -1217,7 +1212,7 @@ def _splash_attention_fwd(
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")

out, (logsumexp,) = _splash_attention_forward( # pytype: disable=wrong-arg-types
out, (logsumexp,) = _splash_attention_forward(
fwd_mask_info,
q,
k,
@@ -1226,11 +1221,10 @@ def _splash_attention_fwd(
mask_value=mask_value,
is_mqa=is_mqa,
block_sizes=block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
save_residuals=True,
mask_function=mask_function,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=interpret,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
)
return out, (
q,
@@ -1355,7 +1349,6 @@ def _splash_attention_bwd_dq(
k_layout: QKVLayout,
v_layout: QKVLayout,
mask_function: MaskFunctionType | None,
interpret: bool,
):
num_q_heads, q_seq_len, head_dim = q.shape
if is_mqa:
@@ -1567,7 +1560,6 @@ def _splash_attention_bwd_dq(
out_shape=out_shapes,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
interpret=interpret,
)(
mask_info.data_next,
mask_info.block_mask,
@@ -1797,7 +1789,6 @@ def _splash_attention_bwd_dkv(
k_layout: QKVLayout,
v_layout: QKVLayout,
mask_function: MaskFunctionType | None,
interpret: bool,
):
num_q_heads, q_seq_len, head_dim = q.shape
if is_mqa:
@@ -2116,7 +2107,6 @@ def _splash_attention_bwd_dkv(
out_shape=out_shapes,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
interpret=interpret,
)(
mask_info.data_next,
mask_info.block_mask,
@@ -2149,7 +2139,6 @@ def _splash_attention_bwd(
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None,
interpret: bool,
res: SplashResidualsType,
do: jax.Array,
) -> tuple[
@@ -2204,7 +2193,6 @@ def _splash_attention_bwd(
k_layout=block_sizes.k_layout,
v_layout=block_sizes.v_layout,
mask_function=mask_function,
interpret=interpret,
)
if not use_fused_bwd_kernel:
assert dq is None
@@ -2226,7 +2214,6 @@ def _splash_attention_bwd(
k_layout=block_sizes.k_layout,
v_layout=block_sizes.v_layout,
mask_function=mask_function,
interpret=interpret,
)
# Match the signature of the fwd function.
assert dq is not None
@@ -2254,7 +2241,6 @@ _splash_attention_custom.defvjp(_splash_attention_fwd, _splash_attention_bwd)
"attn_logits_soft_cap",
"residual_checkpoint_name",
"mask_function",
"interpret",
],
)
def _splash_attention(
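Note: the string list in the hunk above appears to be the static-argument list of the jitted `_splash_attention` wrapper; the revert drops "interpret" from it. Flags of this kind are kept static because they select different traced code paths. The sketch below shows the general `jax.jit(..., static_argnames=...)` pattern with illustrative names only.

import jax
import jax.numpy as jnp
from functools import partial


# Illustrative only: an `interpret`-style flag chooses between code paths at
# trace time, so it is declared static rather than passed as a traced value.
@partial(jax.jit, static_argnames=["use_soft_cap"])
def capped_logits(x, cap, *, use_soft_cap=False):
  if use_soft_cap:
    # Python-level branching is allowed because use_soft_cap is static.
    return jnp.tanh(x / cap) * cap
  return x


x = jnp.linspace(-100.0, 100.0, 5)
print(capped_logits(x, 50.0, use_soft_cap=True))
print(capped_logits(x, 50.0, use_soft_cap=False))  # New static value, new compilation.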
@@ -2273,7 +2259,6 @@ def _splash_attention(
attn_logits_soft_cap: float | None,
residual_checkpoint_name: str | None,
mask_function: MaskFunctionType | None,
interpret: bool,
) -> SplashCustomReturnType:
return _splash_attention_custom(
fwd_mask_info,
@@ -2290,7 +2275,6 @@ def _splash_attention(
attn_logits_soft_cap=attn_logits_soft_cap,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function,
interpret=interpret,
)
@@ -2394,7 +2378,6 @@ def _make_splash_attention(
head_shards: int,
q_seq_shards: int,
residual_checkpoint_name: str | None = None,
interpret: bool = False,
):
if len(mask.shape) != 3:
raise ValueError(f'Unexpected mask shape: {mask.shape}')
@@ -2456,7 +2439,6 @@ def _make_splash_attention(
attn_logits_soft_cap=attn_logits_soft_cap,
residual_checkpoint_name=residual_checkpoint_name,
mask_function=mask_function_fwd,
interpret=interpret,
)
@@ -378,9 +378,10 @@ jax_test(
"tpu_splash_attention_kernel_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
shard_count = 24,
shard_count = 18,
tags = [
"noasan", # Times out.
"nomsan", # Times out.
@@ -397,6 +398,7 @@ jax_test(
"tpu_splash_attention_mask_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
deps = [
@@ -415,6 +417,7 @@ jax_test(
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
@@ -448,6 +451,7 @@ jax_test(
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
@@ -13,7 +13,6 @@
# limitations under the License.

import os
import sys

from absl.testing import absltest
from absl.testing import parameterized
@@ -35,26 +34,14 @@ config.parse_flags_with_absl()


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
class DecodeAttentionTest(jtu.JaxTestCase):

def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(["cpu", "gpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU and GPU the test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32" and not self.INTERPRET:
self.skipTest("Only works on non-Windows platforms")
if not jtu.is_cuda_compute_capability_at_least("8.0"):
self.skipTest("Fused attention only works on GPUs with capability >= sm80")

super().setUp()


class DecodeAttentionTest(PallasBaseTest):
INTERPRET = False

@parameterized.named_parameters(*[
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}",
@@ -92,7 +79,7 @@ class DecodeAttentionTest(PallasBaseTest):
k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)

o = decode_attention.mqa(q, k, v, interpret=self.INTERPRET)
o = decode_attention.mqa(q, k, v)
o_ref = decode_attention.mqa_reference(q, k, v)
np.testing.assert_allclose(o, o_ref, atol=0.05)

@@ -142,13 +129,10 @@ class DecodeAttentionTest(PallasBaseTest):
k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
)

o = decode_attention.gqa(q, k, v, interpret=self.INTERPRET)
o = decode_attention.gqa(q, k, v)
o_ref = decode_attention.gqa_reference(q, k, v)
np.testing.assert_allclose(o, o_ref, atol=0.05)

class DecodeAttentionInterpreterTest(DecodeAttentionTest):
INTERPRET = True


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
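Note: the hunks above remove the shared `PallasBaseTest` base class (with its `INTERPRET` class attribute and the CPU/Windows interpret-mode skips) and fold the remaining skips back into `DecodeAttentionTest`; the `DecodeAttentionInterpreterTest` subclass disappears with it. The toy test below sketches that base-class pattern in isolation; it is not the actual test file.

from absl.testing import absltest
import jax.numpy as jnp
from jax._src import test_util as jtu


class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False

  def setUp(self):
    # Mirrors the skip logic above: on CPU these kernels only run when the
    # Pallas interpreter is enabled.
    if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
      self.skipTest("On CPU the test works only in interpret mode")
    super().setUp()


class ToyKernelTest(PallasBaseTest):

  def test_add(self):
    # Stand-in body; the real tests call the attention kernels with
    # interpret=self.INTERPRET.
    x = jnp.arange(4.0)
    self.assertAllClose(x + 1.0, jnp.add(x, 1.0))


class ToyKernelInterpreterTest(ToyKernelTest):
  INTERPRET = True


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())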
@@ -114,14 +114,14 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
class PallasTest(jtu.JaxTestCase):
INTERPRET = False

def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(["cpu", "gpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU and GPU the test works only in 32-bit")
if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled:
self.skipTest("On GPU the test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
@@ -135,24 +135,28 @@ class PallasBaseTest(jtu.JaxTestCase):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class FusedAttentionTest(PallasBaseTest):
class FusedAttentionTest(PallasTest):

def setUp(self):
super().setUp()
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not intended for TPU")
if jtu.test_device_matches(["cpu", "tpu"]):
self.skipTest("Works only on GPU")

@jtu.parameterized_filterable(
kwargs=[
dict(
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_heads,
head_dim=head_dim,
causal=causal,
use_fwd=use_fwd,
use_segment_ids=use_segment_ids,
kwargs=kwargs,
@parameterized.named_parameters(
*[
(
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}"
f"_{use_fwd=}_{use_segment_ids=}_{kwargs=}"
),
batch_size,
seq_len,
num_heads,
head_dim,
causal,
use_fwd,
use_segment_ids,
kwargs,
)
for (
batch_size,
@@ -179,7 +183,6 @@ class FusedAttentionTest(PallasBaseTest):
)
def test_fused_attention_fwd(
self,
*,
batch_size,
seq_len,
num_heads,
@@ -212,8 +215,7 @@ class FusedAttentionTest(PallasBaseTest):
def impl(q, k, v):
v, _ = jax.vjp(
functools.partial(
attention.mha, causal=causal, segment_ids=segment_ids,
interpret=self.INTERPRET, **kwargs
attention.mha, causal=causal, segment_ids=segment_ids, **kwargs
),
q,
k,
@@ -223,22 +225,25 @@ class FusedAttentionTest(PallasBaseTest):

else:
impl = functools.partial(
attention.mha, causal=causal, segment_ids=segment_ids,
interpret=self.INTERPRET, **kwargs
attention.mha, causal=causal, segment_ids=segment_ids, **kwargs
)
o = impl(q, k, v)
o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal)
np.testing.assert_allclose(o, o_ref, atol=0.05)

@jtu.parameterized_filterable(
kwargs=[
dict(
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_heads,
head_dim=head_dim,
causal=causal,
use_segment_ids=use_segment_ids,
@parameterized.named_parameters(
*[
(
(
f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_"
f"{use_segment_ids=}"
),
batch_size,
seq_len,
num_heads,
head_dim,
causal,
use_segment_ids,
)
for (
batch_size,
@@ -258,7 +263,7 @@ class FusedAttentionTest(PallasBaseTest):
]
)
def test_fused_attention_bwd(
self, *, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
):
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
@@ -278,8 +283,7 @@ class FusedAttentionTest(PallasBaseTest):
segment_ids = None

def f(q, k, v):
return attention.mha(q, k, v, segment_ids, causal=causal,
interpret=self.INTERPRET).sum()
return attention.mha(q, k, v, segment_ids, causal=causal).sum()

def f_ref(q, k, v):
return attention.mha_reference(q, k, v, segment_ids, causal=causal).sum()
@@ -296,7 +300,7 @@ class FusedAttentionInterpreterTest(FusedAttentionTest):
INTERPRET = True


class FusedLayerNormTest(PallasBaseTest):
class FusedLayerNormTest(PallasTest):

def setUp(self):
super().setUp()
@@ -344,7 +348,7 @@ class FusedLayerNormInterpreterTest(FusedLayerNormTest):
INTERPRET = True


class RmsNormTest(PallasBaseTest):
class RmsNormTest(PallasTest):

def setUp(self):
super().setUp()
@@ -392,7 +396,7 @@ class RmsNormInterpreterTest(RmsNormTest):
INTERPRET = True


class SoftmaxTest(PallasBaseTest):
class SoftmaxTest(PallasTest):

def setUp(self):
super().setUp()
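Note: the hunks above also switch the GPU attention tests from `jtu.parameterized_filterable` (which takes a list of kwargs dicts and passes each one to the test as keyword arguments) back to absl's `parameterized.named_parameters` (tuples of a test name plus positional arguments). The sketch below shows the two styles side by side on a toy test case; the parameter names are illustrative.

from absl.testing import absltest, parameterized
from jax._src import test_util as jtu


class ParamStyleTest(jtu.JaxTestCase):

  # Style used before the revert: kwargs dicts, delivered as keyword-only args.
  @jtu.parameterized_filterable(
      kwargs=[
          dict(seq_len=seq_len, causal=causal)
          for seq_len in (128, 256)
          for causal in (False, True)
      ],
  )
  def test_keyword_style(self, *, seq_len, causal):
    self.assertGreater(seq_len, 0)
    self.assertIsInstance(causal, bool)

  # Style restored by the revert: (name, *positional_args) tuples.
  @parameterized.named_parameters(*[
      (f"{seq_len=}_{causal=}", seq_len, causal)
      for seq_len in (128, 256)
      for causal in (False, True)
  ])
  def test_positional_style(self, seq_len, causal):
    self.assertGreater(seq_len, 0)
    self.assertIsInstance(causal, bool)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())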
@@ -290,19 +290,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]:
return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0))


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
class AttentionTest(jtu.JaxTestCase):

def setUp(self):
if not self.INTERPRET:
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only interpret mode supported on non-TPU")
# TODO(b/327487669): selectively re-enable tests that works on TPU v3.
if not jtu.is_device_tpu_at_least(4):
self.skipTest("Not supported on TPU generations <= 3")
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("On CPU the test works only in 32-bit")
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Need TPU devices")
# TODO(b/327487669): selectively re-enable tests that works on TPU v3.
if not jtu.is_device_tpu_at_least(4):
self.skipTest("Not supported on TPU generations <= 3")

super().setUp()

@@ -316,7 +311,7 @@ class PallasBaseTest(jtu.JaxTestCase):
np.testing.assert_allclose(x, y, **kwargs)


class SplashAttentionTest(PallasBaseTest):
class SplashAttentionTest(AttentionTest):

@parameterized.product(
is_mqa=(False, True),
@@ -360,7 +355,6 @@ class SplashAttentionTest(PallasBaseTest):
mask,
block_sizes=block_sizes,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask)
@@ -368,7 +362,6 @@ class SplashAttentionTest(PallasBaseTest):
mask,
block_sizes=block_sizes,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
o = attn(q, k, v, segment_ids)
o_ref = attn_ref(
@@ -423,7 +416,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
save_residuals=True,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask)
@@ -432,7 +424,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
save_residuals=True,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
attn_ref = partial(
attn_ref,
@@ -573,7 +564,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
downcast_smem_data=downcast_smem_data,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
else:
attn_ref = splash.make_masked_mha_reference(mask, backward_impl="custom")
@@ -582,7 +572,6 @@ class SplashAttentionTest(PallasBaseTest):
block_sizes=block_sizes,
downcast_smem_data=downcast_smem_data,
attn_logits_soft_cap=attn_logits_soft_cap,
interpret=self.INTERPRET,
)
o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids)
q32, k32, v32 = jax.tree.map(
@@ -650,9 +639,5 @@ class SplashAttentionTest(PallasBaseTest):
self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2)


class SplashAttentionInterpreterTest(SplashAttentionTest):
INTERPRET = True


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())