Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.

Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).

PiperOrigin-RevId: 459320212
This commit is contained in:
Peter Hawkins 2022-07-06 12:51:07 -07:00 committed by jax authors
parent 354c684873
commit 95e79332c0
2 changed files with 9 additions and 2 deletions

View File

@ -77,12 +77,12 @@ flags.DEFINE_bool(
)
flags.DEFINE_string(
'test_targets', '',
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
flags.DEFINE_string(
'exclude_test_targets', '',
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)

View File

@ -159,6 +159,7 @@ jax_test(
"cpu": 10,
"tpu": 4,
},
tags = ["multiaccelerator"],
deps = [
"//jax:maps",
],
@ -167,6 +168,7 @@ jax_test(
jax_test(
name = "pjit_test",
srcs = ["pjit_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
@ -175,6 +177,7 @@ jax_test(
jax_test(
name = "global_device_array_test",
srcs = ["global_device_array_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
@ -183,6 +186,7 @@ jax_test(
jax_test(
name = "array_test",
srcs = ["array_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
@ -195,6 +199,7 @@ jax_test(
"gpu",
"cpu",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
@ -502,6 +507,7 @@ jax_test(
"gpu": 5,
"tpu": 5,
},
tags = ["multiaccelerator"],
deps = [
":lax_test_lib",
":lax_vmap_test_lib",
@ -634,6 +640,7 @@ jax_test(
"tpu",
"iree",
],
tags = ["multiaccelerator"],
deps = ["//jax:experimental"],
)