Disable Pallas vmap test that is very slow under tsan.

PiperOrigin-RevId: 652505878
This commit is contained in:
Peter Hawkins 2024-07-15 09:27:43 -07:00 committed by jax authors
parent 9c72e67711
commit a1f69713f5

View File

@ -21,6 +21,7 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
from absl.testing import absltest
import jax
from jax import random
from jax._src.lib import xla_extension
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_to_jaxpr
@ -188,6 +189,9 @@ class PallasCallVmapTest(PallasTest):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_small_large_vmap(self):
if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]):
self.skipTest("Test is very slow under TSAN")
# Catches https://github.com/google/jax/issues/18361
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
@ -206,6 +210,9 @@ class PallasCallVmapTest(PallasTest):
np.testing.assert_allclose(out, out_ref)
def test_small_small_large_vmap(self):
if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]):
self.skipTest("Test is very slow under TSAN")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
grid=(2,))