diff --git a/tests/BUILD b/tests/BUILD index e5a0d71b6..8283b53ad 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -189,6 +189,9 @@ jax_test( jax_test( name = "pjit_test", srcs = ["pjit_test.py"], + backend_tags = { + "tpu": ["notsan"], # Times out under tsan. + }, pjrt_c_api_bypass = True, shard_count = { "cpu": 5, @@ -530,6 +533,12 @@ jax_test( jax_test( name = "pmap_test", srcs = ["pmap_test.py"], + backend_tags = { + "tpu": [ + "noasan", + "notsan", + ], # Times out under asan/tsan. + }, pjrt_c_api_bypass = True, shard_count = { "cpu": 30,