diff --git a/tests/BUILD b/tests/BUILD index b5c1cf530..53455b7b8 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -117,7 +117,10 @@ jax_test( name = "fft_test", srcs = ["fft_test.py"], backend_tags = { - "tpu": ["notsan"], # Times out on TPU with tsan. + "tpu": [ + "noasan", + "notsan", + ], # Times out on TPU with asan/tsan. }, shard_count = { "tpu": 20, @@ -555,6 +558,7 @@ jax_test( "nodebug", # Times out. "notsan", # Times out. ], + "cpu": ["notsan"], # Times out }, pjrt_c_api_bypass = True, shard_count = { @@ -650,6 +654,12 @@ jax_test( jax_test( name = "scipy_fft_test", srcs = ["scipy_fft_test.py"], + backend_tags = { + "tpu": [ + "noasan", + "notsan", + ], # Times out on TPU with asan/tsan. + }, shard_count = 4, )