Rename test configs to include GPU variants more consistently.

* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
This commit is contained in:
Peter Hawkins 2025-01-14 11:55:02 -08:00 committed by jax authors
parent f1b894d14a
commit f122f17b27
3 changed files with 14 additions and 14 deletions

View File

@ -71,7 +71,7 @@ jax_multiplatform_test(
"gpu", "gpu",
], ],
enable_configs = [ enable_configs = [
"gpu_2gpu", "gpu_p100x2",
], ],
tags = ["multiaccelerator"], tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"), deps = py_deps("tensorflow_core"),
@ -226,12 +226,12 @@ jax_multiplatform_test(
srcs = ["memories_test.py"], srcs = ["memories_test.py"],
enable_configs = [ enable_configs = [
"cpu", "cpu",
"gpu_2gpu", "gpu_p100x2",
"tpu_v3_2x2", "tpu_v3_2x2",
"tpu_v4_2x2", "tpu_v4_2x2",
"tpu_v5p_2x2", "tpu_v5p_2x2",
"tpu_v5e_4x2", "tpu_v5e_4x2",
"gpu_2gpu_shardy", "gpu_p100x2_shardy",
"tpu_v5e_4x2_shardy", "tpu_v5e_4x2_shardy",
], ],
shard_count = { shard_count = {
@ -250,10 +250,10 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
}, },
enable_configs = [ enable_configs = [
"gpu_2gpu_shardy", "gpu_p100x2_shardy",
"tpu_v3_2x2_shardy", "tpu_v3_2x2_shardy",
"tpu_v3_2x2", "tpu_v3_2x2",
"gpu_2gpu", "gpu_p100x2",
], ],
shard_count = { shard_count = {
"cpu": 5, "cpu": 5,
@ -316,7 +316,7 @@ jax_multiplatform_test(
srcs = ["mock_gpu_test.py"], srcs = ["mock_gpu_test.py"],
enable_backends = ["gpu"], enable_backends = ["gpu"],
enable_configs = [ enable_configs = [
"gpu_2gpu_shardy", "gpu_p100x2_shardy",
], ],
tags = [ tags = [
"config-cuda-only", "config-cuda-only",
@ -701,7 +701,7 @@ jax_multiplatform_test(
srcs = ["multibackend_test.py"], srcs = ["multibackend_test.py"],
enable_configs = [ enable_configs = [
"tpu_v3_2x2", "tpu_v3_2x2",
"gpu_2gpu", "gpu_p100x2",
], ],
) )
@ -1300,7 +1300,7 @@ jax_multiplatform_test(
"tpu_v3_2x2", "tpu_v3_2x2",
"tpu_v4_2x2", "tpu_v4_2x2",
"tpu_v3_2x2_shardy", "tpu_v3_2x2_shardy",
"gpu_2gpu_shardy", "gpu_p100x2_shardy",
], ],
tags = ["multiaccelerator"], tags = ["multiaccelerator"],
deps = [ deps = [
@ -1364,7 +1364,7 @@ jax_multiplatform_test(
name = "shard_map_test", name = "shard_map_test",
srcs = ["shard_map_test.py"], srcs = ["shard_map_test.py"],
enable_configs = [ enable_configs = [
"gpu_2gpu_shardy", "gpu_p100x2_shardy",
"tpu_v3_2x2_shardy", "tpu_v3_2x2_shardy",
], ],
shard_count = { shard_count = {

View File

@ -35,7 +35,7 @@ jax_multiplatform_test(
enable_backends = [], enable_backends = [],
enable_configs = [ enable_configs = [
"gpu_h100", "gpu_h100",
"gpu_h100_2gpu", "gpu_h100x2",
], ],
shard_count = 8, shard_count = 8,
tags = ["multiaccelerator"], tags = ["multiaccelerator"],

View File

@ -77,7 +77,7 @@ jax_multiplatform_test(
], ],
disable_configs = [ disable_configs = [
"gpu_v100", "gpu_v100",
"gpu_x32", "gpu_v100_x32",
"gpu_a100", "gpu_a100",
"gpu_p100", "gpu_p100",
"gpu_p100_x32", "gpu_p100_x32",
@ -97,7 +97,7 @@ jax_multiplatform_test(
], ],
disable_configs = [ disable_configs = [
"gpu_v100", "gpu_v100",
"gpu_x32", "gpu_v100_x32",
"gpu_p100", "gpu_p100",
"gpu_p100_x32", "gpu_p100_x32",
], ],
@ -222,11 +222,11 @@ jax_multiplatform_test(
name = "pallas_shape_poly_test", name = "pallas_shape_poly_test",
srcs = ["pallas_shape_poly_test.py"], srcs = ["pallas_shape_poly_test.py"],
disable_configs = [ disable_configs = [
"gpu_x32",
"gpu_h100", "gpu_h100",
"gpu_p100", "gpu_p100",
"gpu_p100_x32", "gpu_p100_x32",
"gpu_pjrt_c_api", "gpu_v100_x32",
"gpu_p100_pjrt_c_api",
], ],
enable_configs = [ enable_configs = [
"gpu_a100_x32", "gpu_a100_x32",