Replace disable_backends with enable_backends on jax_multiplatform_test.

Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option and say that directly: disable_backends becomes enable_backends, with a default `None` value meaning "all backends are enabled".

We change the relationship among `enable_backends`, `disable_configs`, and `enable_configs` to be the following:
* `enable_backends` selects an initial set of test configurations to enable, based only on the backend.
* `disable_configs` then prunes that set, removing the named test configurations from it.
* `enable_configs` then adds additional named test configurations to the set (see the sketch below).
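
For example, a target that should run on all GPU configs except the A100 one, plus one extra H100 config, could now be written as follows (a minimal sketch; the target name `foo_test` and the particular config names are illustrative only, not taken from this change):

    jax_multiplatform_test(
        name = "foo_test",  # hypothetical target, for illustration only
        srcs = ["foo_test.py"],
        # 1. Start from every test config whose backend is "gpu".
        enable_backends = ["gpu"],
        # 2. Prune the A100 config from that set.
        disable_configs = ["gpu_a100"],
        # 3. Add one extra named config to the set.
        enable_configs = ["gpu_h100_x32"],
    )

Leaving `enable_backends` unset (i.e. `None`) keeps every backend enabled, while `enable_backends = []` starts from an empty set so that only the configs named in `enable_configs` run.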

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
Peter Hawkins 2024-09-27 06:14:50 -07:00 committed by jax authors
parent 5740ab3b02
commit 26632fd344
7 changed files with 76 additions and 235 deletions

@ -28,25 +28,11 @@ package(
jax_generate_backend_suites()
DISABLED_BACKENDS = [
"cpu",
"tpu",
]
DISABLED_CONFIGS = [
"gpu_v100",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_x32",
"gpu_pjrt_c_api",
]
jax_multiplatform_test(
name = "matmul_bench",
srcs = ["matmul_bench.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
tags = ["notap"],
deps = [
"//jax:mosaic_gpu",

@ -32,10 +32,7 @@ jax_multiplatform_test(
name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"],
data = [":foo"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
tags = ["notap"],
deps = [
"//jax:extend",

@ -13,7 +13,7 @@
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps")
licenses(["notice"])
@ -48,18 +48,17 @@ py_library(
],
)
jax_py_test(
jax_multiplatform_test(
name = "run_matmul",
srcs = ["matmul.py"],
enable_backends = [],
enable_configs = ["gpu_h100"],
main = "matmul.py",
tags = [
"manual",
"notap",
"requires-gpu-sm90-only",
],
deps = [
"//jax",
"//jax:mosaic_gpu",
"//learning/brain/research/jax:gpu_support",
] + py_deps("numpy"),
)

@ -231,15 +231,22 @@ def jax_multiplatform_test(
shard_count = None,
deps = [],
data = [],
disable_backends = None, # buildifier: disable=unused-variable
enable_backends = None,
backend_variant_args = {}, # buildifier: disable=unused-variable
backend_tags = {}, # buildifier: disable=unused-variable
disable_configs = None, # buildifier: disable=unused-variable
enable_configs = None, # buildifier: disable=unused-variable
enable_configs = [],
config_tags_overrides = None, # buildifier: disable=unused-variable
tags = [],
main = None,
pjrt_c_api_bypass = False): # buildifier: disable=unused-variable
# enable_configs and disable_configs do not do anything in OSS, only in Google's CI.
# The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is
# as follows:
# 1. `enable_backends` is applied first, enabling all test configs for the given backends.
# 2. `disable_configs` is applied second, disabling the named test configs.
# 3. `enable_configs` is applied last, enabling the named test configs.
if main == None:
if len(srcs) == 1:
main = srcs[0]
@ -256,7 +263,7 @@ def jax_multiplatform_test(
"--jax_platform_name=" + backend,
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if disable_backends and backend in disable_backends:
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
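
The new condition above is easier to read as a positive predicate (a minimal sketch of the same logic; the helper name `backend_is_enabled` is made up for illustration and does not exist in jax.bzl). A per-backend test target is tagged `manual`, and therefore skipped by default, exactly when this returns False:

    def backend_is_enabled(backend, enable_backends, enable_configs):
        # None means no backend filter was given, so every backend stays enabled.
        if enable_backends == None:
            return True
        # The backend was listed explicitly.
        if backend in enable_backends:
            return True
        # Otherwise the backend is still enabled if some explicitly named config
        # targets it, e.g. "gpu_h100" starts with "gpu".
        return any([config.startswith(backend) for config in enable_configs])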

@ -66,7 +66,10 @@ jax_py_test(
jax_multiplatform_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
disable_backends = ["tpu"],
enable_backends = [
"cpu",
"gpu",
],
tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"),
)
@ -160,10 +163,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "gpu_memory_flags_test_no_preallocation",
srcs = ["gpu_memory_flags_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {
"XLA_PYTHON_CLIENT_PREALLOCATE": "0",
},
@ -173,10 +173,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "gpu_memory_flags_test",
srcs = ["gpu_memory_flags_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {
"XLA_PYTHON_CLIENT_PREALLOCATE": "1",
},
@ -273,10 +270,7 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
tags = [
"config-cuda-only",
@ -290,10 +284,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "mock_gpu_test",
srcs = ["mock_gpu_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
tags = [
"config-cuda-only",
],
@ -556,11 +547,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "lax_metal_test",
srcs = ["lax_metal_test.py"],
disable_backends = [
"cpu",
"gpu",
"tpu",
],
enable_backends = ["metal"],
tags = ["notap"],
deps = [
"//jax:internal_test_util",
@ -649,10 +636,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)
jax_py_test(
@ -672,10 +656,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "multi_device_test",
srcs = ["multi_device_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)
jax_multiplatform_test(
@ -734,10 +715,7 @@ jax_multiplatform_test(
name = "polynomial_test",
srcs = ["polynomial_test.py"],
# No implementation of nonsymmetric Eigendecomposition.
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
shard_count = {
"cpu": 10,
},
@ -753,25 +731,18 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "heap_profiler_test",
srcs = ["heap_profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)
jax_multiplatform_test(
name = "profiler_test",
srcs = ["profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)
jax_multiplatform_test(
name = "pytorch_interoperability_test",
srcs = ["pytorch_interoperability_test.py"],
disable_backends = ["tpu"],
# The following cases are disabled because they time out in Google's CI, mostly because the
# CUDA kernels in Torch take a very long time to compile.
disable_configs = [
@ -779,6 +750,10 @@ jax_multiplatform_test(
"gpu_a100", # Pytorch A100 build times out in Google's CI.
"gpu_h100", # Pytorch H100 build times out in Google's CI.
],
enable_backends = [
"cpu",
"gpu",
],
tags = [
"not_build:arm",
# TODO(b/355237462): Re-enable once MSAN issue is addressed.
@ -1019,16 +994,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "sparse_nm_test",
srcs = ["sparse_nm_test.py"],
config_tags_overrides = {
"gpu_a100": {
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"cpu",
"gpu",
"tpu",
],
enable_backends = [],
enable_configs = [
"gpu_a100",
"gpu_h100",
@ -1386,13 +1352,10 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "experimental_rnn_test",
srcs = ["experimental_rnn_test.py"],
disable_backends = [
"tpu",
"cpu",
],
disable_configs = [
"gpu_a100", # Numerical precision problems.
],
enable_backends = ["gpu"],
shard_count = 15,
deps = [
"//jax:rnn",
@ -1505,10 +1468,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "fused_attention_stablehlo_test",
srcs = ["fused_attention_stablehlo_test.py"],
disable_backends = [
"tpu",
"cpu",
],
enable_backends = ["gpu"],
shard_count = {
"gpu": 4,
},
@ -1542,10 +1502,7 @@ jax_py_test(
jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
enable_configs = [
"gpu_a100",
"gpu_h100",

@ -28,31 +28,16 @@ package(
jax_generate_backend_suites()
DISABLED_BACKENDS = [
"cpu",
"tpu",
]
DISABLED_CONFIGS = [
"gpu_a100",
"gpu_a100_x32",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
"gpu_v100",
"gpu_x32",
]
jax_multiplatform_test(
name = "gpu_test",
srcs = ["gpu_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = [
"gpu_h100",
"gpu_h100_2gpu",
],
shard_count = 4,
tags = ["multiaccelerator"],
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
@ -61,8 +46,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "matmul_test",
srcs = ["matmul_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
shard_count = 5,
deps = [
"//jax:mosaic_gpu",
@ -73,8 +58,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "flash_attention",
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
tags = ["notap"],
deps = [
@ -85,8 +70,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "flash_attention_test",
srcs = ["flash_attention_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
deps = [
"//jax:mosaic_gpu",
"//jax/experimental/mosaic/gpu/examples:flash_attention",

@ -38,11 +38,9 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_p100",
"gpu_p100_x32",
enable_backends = [
"cpu",
"tpu",
],
enable_configs = [
"gpu_a100_x32",
@ -75,9 +73,6 @@ jax_multiplatform_test(
"gpu_p100_x32",
"gpu_h100",
],
shard_count = {
"tpu": 1,
},
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
@ -130,8 +125,9 @@ jax_multiplatform_test(
srcs = [
"indexing_test.py",
],
disable_backends = [
"gpu",
enable_backends = [
"cpu",
"tpu",
],
tags = [
"noasan", # Times out.
@ -154,14 +150,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
],
enable_backends = ["cpu"],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
@ -186,19 +175,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_a100",
"gpu_a100_x32",
"gpu_p100",
"gpu_p100_x32",
"gpu_h100",
],
enable_backends = [],
enable_configs = [
"gpu_h100_x32",
],
@ -220,15 +197,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
],
enable_backends = ["cpu"],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
@ -251,15 +220,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
],
enable_backends = ["cpu"],
enable_configs = [
"gpu_a100_x32",
],
@ -303,10 +264,7 @@ jax_multiplatform_test(
srcs = [
"pallas_error_handling_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
@ -321,10 +279,7 @@ jax_multiplatform_test(
srcs = [
"tpu_all_gather_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
deps = [
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
@ -335,10 +290,7 @@ jax_multiplatform_test(
srcs = [
"tpu_gmm_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
shard_count = 50,
tags = [
"noasan", # Times out.
@ -360,10 +312,7 @@ jax_multiplatform_test(
srcs = ["tpu_pallas_test.py"],
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
args = ["--logtostderr"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
@ -376,8 +325,9 @@ jax_multiplatform_test(
srcs = [
"tpu_ops_test.py",
],
disable_backends = [
"gpu",
enable_backends = [
"cpu",
"tpu",
],
deps = [
"//jax:pallas",
@ -390,10 +340,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
@ -404,10 +351,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_pallas_pipeline_test",
srcs = ["tpu_pallas_pipeline_test.py"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
shard_count = 5,
tags = [
"noasan", # Times out.
@ -424,10 +368,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_pallas_async_test",
srcs = ["tpu_pallas_async_test.py"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
tags = [
],
deps = [
@ -438,10 +379,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_pallas_mesh_test",
srcs = ["tpu_pallas_mesh_test.py"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
tags = [
"noasan",
"nomsan",
@ -458,10 +396,7 @@ jax_multiplatform_test(
srcs = [
"tpu_pallas_random_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
@ -474,10 +409,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_paged_attention_kernel_test",
srcs = ["tpu_paged_attention_kernel_test.py"],
disable_backends = [
"cpu",
"gpu",
],
enable_backends = ["tpu"],
shard_count = 5,
tags = [
"noasan", # Times out.
@ -494,10 +426,7 @@ jax_multiplatform_test(
srcs = [
"tpu_splash_attention_kernel_test.py",
],
disable_backends = [
"gpu",
"cpu",
],
enable_backends = ["tpu"],
shard_count = 24,
tags = [
"noasan", # Times out.
@ -514,8 +443,9 @@ jax_multiplatform_test(
srcs = [
"tpu_splash_attention_mask_test.py",
],
disable_backends = [
"gpu",
enable_backends = [
"cpu",
"tpu",
],
deps = [
"//jax:pallas_tpu_ops",
@ -532,17 +462,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"tpu",
],
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_p100",
"gpu_p100_x32",
"gpu_a100",
"gpu_h100",
],
enable_backends = ["cpu"],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
@ -565,17 +485,7 @@ jax_multiplatform_test(
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"tpu",
],
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
],
enable_backends = ["cpu"],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",