diff --git a/tests/BUILD b/tests/BUILD index de924bdca..9660f85d2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -417,17 +417,14 @@ jax_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { - "cpu": [ - "noasan", - "notsan", - ], # Test times out. - "tpu": ["noasan"], # Test times out. + "cpu": ["notsan"], # Test times out. }, shard_count = { "cpu": 40, "gpu": 40, "tpu": 40, }, + tags = ["noasan"], # Test times out on all backends ) jax_test( @@ -687,6 +684,9 @@ jax_test( name = "pmap_test", srcs = ["pmap_test.py"], backend_tags = { + "gpu": [ + "noasan", # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + ], "tpu": [ "noasan", # Times out under asan. ], @@ -743,6 +743,10 @@ jax_test( jax_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], + backend_tags = { + # PyTorch leaks dlpack metadata https://github.com/pytorch/pytorch/issues/117058 + "gpu": ["noasan"], + }, disable_backends = ["tpu"], tags = ["not_build:arm"], deps = py_deps("torch"),