Rename test configs to include GPU variants more consistently.

* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
This commit is contained in:
Peter Hawkins 2025-01-14 11:55:02 -08:00 committed by jax authors
parent f1b894d14a
commit f122f17b27
3 changed files with 14 additions and 14 deletions

View File

@ -71,7 +71,7 @@ jax_multiplatform_test(
"gpu",
],
enable_configs = [
"gpu_2gpu",
"gpu_p100x2",
],
tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"),
@ -226,12 +226,12 @@ jax_multiplatform_test(
srcs = ["memories_test.py"],
enable_configs = [
"cpu",
"gpu_2gpu",
"gpu_p100x2",
"tpu_v3_2x2",
"tpu_v4_2x2",
"tpu_v5p_2x2",
"tpu_v5e_4x2",
"gpu_2gpu_shardy",
"gpu_p100x2_shardy",
"tpu_v5e_4x2_shardy",
],
shard_count = {
@ -250,10 +250,10 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"gpu_2gpu_shardy",
"gpu_p100x2_shardy",
"tpu_v3_2x2_shardy",
"tpu_v3_2x2",
"gpu_2gpu",
"gpu_p100x2",
],
shard_count = {
"cpu": 5,
@ -316,7 +316,7 @@ jax_multiplatform_test(
srcs = ["mock_gpu_test.py"],
enable_backends = ["gpu"],
enable_configs = [
"gpu_2gpu_shardy",
"gpu_p100x2_shardy",
],
tags = [
"config-cuda-only",
@ -701,7 +701,7 @@ jax_multiplatform_test(
srcs = ["multibackend_test.py"],
enable_configs = [
"tpu_v3_2x2",
"gpu_2gpu",
"gpu_p100x2",
],
)
@ -1300,7 +1300,7 @@ jax_multiplatform_test(
"tpu_v3_2x2",
"tpu_v4_2x2",
"tpu_v3_2x2_shardy",
"gpu_2gpu_shardy",
"gpu_p100x2_shardy",
],
tags = ["multiaccelerator"],
deps = [
@ -1364,7 +1364,7 @@ jax_multiplatform_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],
enable_configs = [
"gpu_2gpu_shardy",
"gpu_p100x2_shardy",
"tpu_v3_2x2_shardy",
],
shard_count = {

View File

@ -35,7 +35,7 @@ jax_multiplatform_test(
enable_backends = [],
enable_configs = [
"gpu_h100",
"gpu_h100_2gpu",
"gpu_h100x2",
],
shard_count = 8,
tags = ["multiaccelerator"],

View File

@ -77,7 +77,7 @@ jax_multiplatform_test(
],
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_v100_x32",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
@ -97,7 +97,7 @@ jax_multiplatform_test(
],
disable_configs = [
"gpu_v100",
"gpu_x32",
"gpu_v100_x32",
"gpu_p100",
"gpu_p100_x32",
],
@ -222,11 +222,11 @@ jax_multiplatform_test(
name = "pallas_shape_poly_test",
srcs = ["pallas_shape_poly_test.py"],
disable_configs = [
"gpu_x32",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
"gpu_v100_x32",
"gpu_p100_pjrt_c_api",
],
enable_configs = [
"gpu_a100_x32",