mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU). PiperOrigin-RevId: 459320212
This commit is contained in:
parent
354c684873
commit
95e79332c0
@ -77,12 +77,12 @@ flags.DEFINE_bool(
|
|||||||
)
|
)
|
||||||
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
'test_targets', '',
|
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
||||||
'Regular expression specifying which tests to run, called via re.search on '
|
'Regular expression specifying which tests to run, called via re.search on '
|
||||||
'the test name. If empty or unspecified, run all tests.'
|
'the test name. If empty or unspecified, run all tests.'
|
||||||
)
|
)
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
'exclude_test_targets', '',
|
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
||||||
'Regular expression specifying which tests NOT to run, called via re.search '
|
'Regular expression specifying which tests NOT to run, called via re.search '
|
||||||
'on the test name. If empty or unspecified, run all tests.'
|
'on the test name. If empty or unspecified, run all tests.'
|
||||||
)
|
)
|
||||||
|
@ -159,6 +159,7 @@ jax_test(
|
|||||||
"cpu": 10,
|
"cpu": 10,
|
||||||
"tpu": 4,
|
"tpu": 4,
|
||||||
},
|
},
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:maps",
|
"//jax:maps",
|
||||||
],
|
],
|
||||||
@ -167,6 +168,7 @@ jax_test(
|
|||||||
jax_test(
|
jax_test(
|
||||||
name = "pjit_test",
|
name = "pjit_test",
|
||||||
srcs = ["pjit_test.py"],
|
srcs = ["pjit_test.py"],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:experimental",
|
"//jax:experimental",
|
||||||
],
|
],
|
||||||
@ -175,6 +177,7 @@ jax_test(
|
|||||||
jax_test(
|
jax_test(
|
||||||
name = "global_device_array_test",
|
name = "global_device_array_test",
|
||||||
srcs = ["global_device_array_test.py"],
|
srcs = ["global_device_array_test.py"],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:experimental",
|
"//jax:experimental",
|
||||||
],
|
],
|
||||||
@ -183,6 +186,7 @@ jax_test(
|
|||||||
jax_test(
|
jax_test(
|
||||||
name = "array_test",
|
name = "array_test",
|
||||||
srcs = ["array_test.py"],
|
srcs = ["array_test.py"],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:experimental",
|
"//jax:experimental",
|
||||||
],
|
],
|
||||||
@ -195,6 +199,7 @@ jax_test(
|
|||||||
"gpu",
|
"gpu",
|
||||||
"cpu",
|
"cpu",
|
||||||
],
|
],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:experimental",
|
"//jax:experimental",
|
||||||
],
|
],
|
||||||
@ -502,6 +507,7 @@ jax_test(
|
|||||||
"gpu": 5,
|
"gpu": 5,
|
||||||
"tpu": 5,
|
"tpu": 5,
|
||||||
},
|
},
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
":lax_test_lib",
|
":lax_test_lib",
|
||||||
":lax_vmap_test_lib",
|
":lax_vmap_test_lib",
|
||||||
@ -634,6 +640,7 @@ jax_test(
|
|||||||
"tpu",
|
"tpu",
|
||||||
"iree",
|
"iree",
|
||||||
],
|
],
|
||||||
|
tags = ["multiaccelerator"],
|
||||||
deps = ["//jax:experimental"],
|
deps = ["//jax:experimental"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user