diff --git a/tests/BUILD b/tests/BUILD index 64b41fc58..1ce103b32 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -555,15 +555,6 @@ jax_test( jax_test( name = "pmap_test", srcs = ["pmap_test.py"], - backend_tags = { - "tpu": [ - "noasan", # Times out. - "nomsan", # Times out. - "nodebug", # Times out. - "notsan", # Times out. - ], - "cpu": ["notsan"], # Times out - }, pjrt_c_api_bypass = True, shard_count = { "cpu": 30, @@ -694,11 +685,13 @@ jax_test( "cpu": [ "noasan", # Test times out under asan. ], + # TPU test times out under asan/msan/tsan (b/260710050) "tpu": [ "noasan", + "nomsan", "notsan", "optonly", - ], # Test times out under asan/tsan. + ], }, shard_count = { "cpu": 40,