diff --git a/tests/BUILD b/tests/BUILD index 0d8494df4..16ce56189 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -572,6 +572,11 @@ jax_test( jax_test( name = "pmap_test", srcs = ["pmap_test.py"], + backend_tags = { + "tpu": [ + "noasan", # Times out under asan. + ], + }, pjrt_c_api_bypass = True, shard_count = { "cpu": 30, @@ -771,6 +776,11 @@ jax_test( name = "sparsify_test", srcs = ["sparsify_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], + backend_tags = { + "tpu": [ + "noasan", # Times out under asan. + ], + }, shard_count = { "cpu": 5, "gpu": 20,