diff --git a/tests/BUILD b/tests/BUILD index 9660f85d2..c020de74a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -214,6 +214,9 @@ jax_test( name = "xmap_test", srcs = ["xmap_test.py"], backend_tags = { + "gpu": [ + "noasan", # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + ], "tpu": [ "noasan", # Times out. "nomsan", # Times out. @@ -244,6 +247,7 @@ jax_test( srcs = ["pjit_test.py"], backend_tags = { "tpu": ["notsan"], # Times out under tsan. + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, shard_count = { "cpu": 5, @@ -273,6 +277,9 @@ jax_test( jax_test( name = "pgle_test", srcs = ["pgle_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, disable_backends = [ "cpu", "tpu", @@ -1191,6 +1198,9 @@ jax_test( jax_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, enable_configs = [ "gpu", "cpu", @@ -1210,6 +1220,9 @@ jax_test( jax_test( name = "python_callback_test", srcs = ["python_callback_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, tags = ["multiaccelerator"], deps = [ "//jax:experimental", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index cad4c103f..50fb868c0 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_test", "py_deps", ) +load("@rules_python//python:defs.bzl", "py_test") licenses(["notice"]) @@ -31,6 +31,9 @@ jax_test( srcs = [ "pallas_test.py", ], + backend_tags = { + "gpu": ["noasan"], # https://github.com/openai/triton/issues/2918 + }, config_tags_overrides = { "gpu_x32": { "ondemand": False, # Include in presubmit.