diff --git a/tests/BUILD b/tests/BUILD index 087372eea..0ca66be5f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,6 +34,7 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], + enable_configs = ["tpu_v3_2x2"], shard_count = 10, ) @@ -70,6 +71,9 @@ jax_multiplatform_test( "cpu", "gpu", ], + enable_configs = [ + "gpu_2gpu", + ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) @@ -214,6 +218,14 @@ jax_py_test( jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], + enable_configs = [ + "cpu", + "gpu_2gpu", + "tpu_v3_2x2", + "tpu_v4_2x2", + "tpu_v5p_2x2", + "tpu_v5e_4x2", + ], shard_count = { "tpu": 5, }, @@ -234,6 +246,8 @@ jax_multiplatform_test( "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", "tpu_v4_2x2_shardy", + "tpu_v3_2x2", + "gpu_2gpu", ], shard_count = { "cpu": 5, @@ -258,6 +272,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], + enable_configs = [ + "tpu_v3_2x2", + "tpu_v5e_4x2", + "tpu_v4_2x2", + ], deps = [ "//jax:experimental", ], @@ -298,6 +317,9 @@ jax_multiplatform_test( backend_tags = { "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, + enable_configs = [ + "tpu_v3_2x2", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -644,6 +666,10 @@ jax_py_test( jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], + enable_configs = [ + "tpu_v3_2x2", + "gpu_2gpu", + ], ) jax_multiplatform_test( @@ -693,6 +719,10 @@ jax_multiplatform_test( "requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit. ], }, + enable_configs = [ + "gpu_v100", + "tpu_v3_2x2", + ], shard_count = { "cpu": 30, "gpu": 30, @@ -1030,6 +1060,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], + enable_configs = ["tpu_v3_2x2"], shard_count = { "gpu": 2, "tpu": 4, @@ -1187,8 +1218,11 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], tags = ["multiaccelerator"], ) @@ -1197,8 +1231,11 @@ jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], ) @@ -1208,6 +1245,11 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + enable_configs = [ + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -1218,8 +1260,11 @@ jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index db1f09ae0..b5af90272 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -244,6 +244,9 @@ jax_multiplatform_test( "tpu_all_gather_test.py", ], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + ], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), @@ -277,6 +280,10 @@ jax_multiplatform_test( # The flag is necessary for ``pl.debug_print`` tests to work on TPU. args = ["--logtostderr"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e", + "tpu_v5p_1x1", + ], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -305,6 +312,12 @@ jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_2x2", + "tpu_v4_2x2", + "tpu_v3_2x2", + ], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -316,6 +329,10 @@ jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_1x1", + ], shard_count = 5, tags = [ "noasan", # Times out. @@ -333,7 +350,9 @@ jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], - tags = [ + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_1x1", ], deps = [ "//jax:pallas_tpu",