diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 323bf0357..c3e26de55 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -77,12 +77,12 @@ flags.DEFINE_bool( ) flags.DEFINE_string( - 'test_targets', '', + 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), 'Regular expression specifying which tests to run, called via re.search on ' 'the test name. If empty or unspecified, run all tests.' ) flags.DEFINE_string( - 'exclude_test_targets', '', + 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), 'Regular expression specifying which tests NOT to run, called via re.search ' 'on the test name. If empty or unspecified, run all tests.' ) diff --git a/tests/BUILD b/tests/BUILD index 2c0431811..6b38531b9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -159,6 +159,7 @@ jax_test( "cpu": 10, "tpu": 4, }, + tags = ["multiaccelerator"], deps = [ "//jax:maps", ], @@ -167,6 +168,7 @@ jax_test( jax_test( name = "pjit_test", srcs = ["pjit_test.py"], + tags = ["multiaccelerator"], deps = [ "//jax:experimental", ], @@ -175,6 +177,7 @@ jax_test( jax_test( name = "global_device_array_test", srcs = ["global_device_array_test.py"], + tags = ["multiaccelerator"], deps = [ "//jax:experimental", ], @@ -183,6 +186,7 @@ jax_test( jax_test( name = "array_test", srcs = ["array_test.py"], + tags = ["multiaccelerator"], deps = [ "//jax:experimental", ], @@ -195,6 +199,7 @@ jax_test( "gpu", "cpu", ], + tags = ["multiaccelerator"], deps = [ "//jax:experimental", ], @@ -502,6 +507,7 @@ jax_test( "gpu": 5, "tpu": 5, }, + tags = ["multiaccelerator"], deps = [ ":lax_test_lib", ":lax_vmap_test_lib", @@ -634,6 +640,7 @@ jax_test( "tpu", "iree", ], + tags = ["multiaccelerator"], deps = ["//jax:experimental"], )