diff --git a/tests/BUILD b/tests/BUILD index a673a5aef..cfa2083b2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -71,7 +71,7 @@ jax_multiplatform_test( "gpu", ], enable_configs = [ - "gpu_2gpu", + "gpu_p100x2", ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), @@ -226,12 +226,12 @@ jax_multiplatform_test( srcs = ["memories_test.py"], enable_configs = [ "cpu", - "gpu_2gpu", + "gpu_p100x2", "tpu_v3_2x2", "tpu_v4_2x2", "tpu_v5p_2x2", "tpu_v5e_4x2", - "gpu_2gpu_shardy", + "gpu_p100x2_shardy", "tpu_v5e_4x2_shardy", ], shard_count = { @@ -250,10 +250,10 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu_2gpu_shardy", + "gpu_p100x2_shardy", "tpu_v3_2x2_shardy", "tpu_v3_2x2", - "gpu_2gpu", + "gpu_p100x2", ], shard_count = { "cpu": 5, @@ -316,7 +316,7 @@ jax_multiplatform_test( srcs = ["mock_gpu_test.py"], enable_backends = ["gpu"], enable_configs = [ - "gpu_2gpu_shardy", + "gpu_p100x2_shardy", ], tags = [ "config-cuda-only", @@ -701,7 +701,7 @@ jax_multiplatform_test( srcs = ["multibackend_test.py"], enable_configs = [ "tpu_v3_2x2", - "gpu_2gpu", + "gpu_p100x2", ], ) @@ -1300,7 +1300,7 @@ jax_multiplatform_test( "tpu_v3_2x2", "tpu_v4_2x2", "tpu_v3_2x2_shardy", - "gpu_2gpu_shardy", + "gpu_p100x2_shardy", ], tags = ["multiaccelerator"], deps = [ @@ -1364,7 +1364,7 @@ jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], enable_configs = [ - "gpu_2gpu_shardy", + "gpu_p100x2_shardy", "tpu_v3_2x2_shardy", ], shard_count = { diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 4d6d0b9ca..a52d36e8f 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -35,7 +35,7 @@ jax_multiplatform_test( enable_backends = [], enable_configs = [ "gpu_h100", - "gpu_h100_2gpu", + "gpu_h100x2", ], shard_count = 8, tags = ["multiaccelerator"], diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index bd589d880..44f11e951 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -77,7 +77,7 @@ jax_multiplatform_test( ], disable_configs = [ "gpu_v100", - "gpu_x32", + "gpu_v100_x32", "gpu_a100", "gpu_p100", "gpu_p100_x32", @@ -97,7 +97,7 @@ jax_multiplatform_test( ], disable_configs = [ "gpu_v100", - "gpu_x32", + "gpu_v100_x32", "gpu_p100", "gpu_p100_x32", ], @@ -222,11 +222,11 @@ jax_multiplatform_test( name = "pallas_shape_poly_test", srcs = ["pallas_shape_poly_test.py"], disable_configs = [ - "gpu_x32", "gpu_h100", "gpu_p100", "gpu_p100_x32", - "gpu_pjrt_c_api", + "gpu_v100_x32", + "gpu_p100_pjrt_c_api", ], enable_configs = [ "gpu_a100_x32",