diff --git a/tests/BUILD b/tests/BUILD index f957810e9..90ae35558 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -606,6 +606,9 @@ jax_test( jax_test( name = "random_test", srcs = ["random_test.py"], + backend_tags = { + "cpu": ["notsan"], # Times out + }, shard_count = { "cpu": 30, "gpu": 30, @@ -705,7 +708,10 @@ jax_test( "tpu": 50, "iree": 10, }, - tags = ["noasan"], # Test times out under asan. + tags = [ + "noasan", + "notsan", + ], # Test times out under asan/tsan. deps = [ "//jax:experimental_sparse", ] + py_deps("scipy"),