diff --git a/tests/BUILD b/tests/BUILD index 1ce103b32..21d8657a3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -169,6 +169,9 @@ py_test( jax_test( name = "xmap_test", srcs = ["xmap_test.py"], + backend_tags = { + "tpu": ["noasan"], # Times out. + }, pjrt_c_api_bypass = True, shard_count = { "cpu": 10,