diff --git a/tests/BUILD b/tests/BUILD index efdd8265d..f957810e9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -318,6 +318,9 @@ jax_test( jax_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], + backend_tags = { + "tpu": ["noasan"], # Test times out. + }, pjrt_c_api_bypass = True, shard_count = { "cpu": 40, @@ -449,6 +452,9 @@ jax_test( jax_test( name = "lax_vmap_test", srcs = ["lax_vmap_test.py"], + backend_tags = { + "tpu": ["noasan"], # Test times out. + }, shard_count = { "cpu": 40, "gpu": 40, @@ -606,6 +612,7 @@ jax_test( "tpu": 30, "iree": 30, }, + tags = ["noasan"], # Times out ) # TODO(b/199564969): remove once we always enable_custom_prng @@ -614,7 +621,10 @@ jax_test( srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], backend_tags = { - "cpu": ["noasan"], # Times out under asan. + "cpu": [ + "noasan", + "notsan", + ], # Times out under asan/tsan. }, main = "random_test.py", shard_count = { @@ -675,6 +685,10 @@ jax_test( "tpu": 40, "iree": 10, }, + tags = [ + "noasan", + "notsan", + ], # Times out ) jax_test( @@ -682,9 +696,7 @@ jax_test( srcs = ["sparse_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], backend_tags = { - "cpu": [ - "noasan", # Test times out under asan. - ], + "cpu": ["notsan"], # Times out "tpu": ["optonly"], }, shard_count = { @@ -693,6 +705,7 @@ jax_test( "tpu": 50, "iree": 10, }, + tags = ["noasan"], # Test times out under asan. deps = [ "//jax:experimental_sparse", ] + py_deps("scipy"),