mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration. * Rename "_2gpu" test variants to "x2" variants, since this is more succinct. This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run. PiperOrigin-RevId: 715468944
This commit is contained in:
parent
f1b894d14a
commit
f122f17b27
18
tests/BUILD
18
tests/BUILD
@ -71,7 +71,7 @@ jax_multiplatform_test(
|
|||||||
"gpu",
|
"gpu",
|
||||||
],
|
],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_2gpu",
|
"gpu_p100x2",
|
||||||
],
|
],
|
||||||
tags = ["multiaccelerator"],
|
tags = ["multiaccelerator"],
|
||||||
deps = py_deps("tensorflow_core"),
|
deps = py_deps("tensorflow_core"),
|
||||||
@ -226,12 +226,12 @@ jax_multiplatform_test(
|
|||||||
srcs = ["memories_test.py"],
|
srcs = ["memories_test.py"],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"cpu",
|
"cpu",
|
||||||
"gpu_2gpu",
|
"gpu_p100x2",
|
||||||
"tpu_v3_2x2",
|
"tpu_v3_2x2",
|
||||||
"tpu_v4_2x2",
|
"tpu_v4_2x2",
|
||||||
"tpu_v5p_2x2",
|
"tpu_v5p_2x2",
|
||||||
"tpu_v5e_4x2",
|
"tpu_v5e_4x2",
|
||||||
"gpu_2gpu_shardy",
|
"gpu_p100x2_shardy",
|
||||||
"tpu_v5e_4x2_shardy",
|
"tpu_v5e_4x2_shardy",
|
||||||
],
|
],
|
||||||
shard_count = {
|
shard_count = {
|
||||||
@ -250,10 +250,10 @@ jax_multiplatform_test(
|
|||||||
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
||||||
},
|
},
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_2gpu_shardy",
|
"gpu_p100x2_shardy",
|
||||||
"tpu_v3_2x2_shardy",
|
"tpu_v3_2x2_shardy",
|
||||||
"tpu_v3_2x2",
|
"tpu_v3_2x2",
|
||||||
"gpu_2gpu",
|
"gpu_p100x2",
|
||||||
],
|
],
|
||||||
shard_count = {
|
shard_count = {
|
||||||
"cpu": 5,
|
"cpu": 5,
|
||||||
@ -316,7 +316,7 @@ jax_multiplatform_test(
|
|||||||
srcs = ["mock_gpu_test.py"],
|
srcs = ["mock_gpu_test.py"],
|
||||||
enable_backends = ["gpu"],
|
enable_backends = ["gpu"],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_2gpu_shardy",
|
"gpu_p100x2_shardy",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"config-cuda-only",
|
"config-cuda-only",
|
||||||
@ -701,7 +701,7 @@ jax_multiplatform_test(
|
|||||||
srcs = ["multibackend_test.py"],
|
srcs = ["multibackend_test.py"],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"tpu_v3_2x2",
|
"tpu_v3_2x2",
|
||||||
"gpu_2gpu",
|
"gpu_p100x2",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1300,7 +1300,7 @@ jax_multiplatform_test(
|
|||||||
"tpu_v3_2x2",
|
"tpu_v3_2x2",
|
||||||
"tpu_v4_2x2",
|
"tpu_v4_2x2",
|
||||||
"tpu_v3_2x2_shardy",
|
"tpu_v3_2x2_shardy",
|
||||||
"gpu_2gpu_shardy",
|
"gpu_p100x2_shardy",
|
||||||
],
|
],
|
||||||
tags = ["multiaccelerator"],
|
tags = ["multiaccelerator"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -1364,7 +1364,7 @@ jax_multiplatform_test(
|
|||||||
name = "shard_map_test",
|
name = "shard_map_test",
|
||||||
srcs = ["shard_map_test.py"],
|
srcs = ["shard_map_test.py"],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_2gpu_shardy",
|
"gpu_p100x2_shardy",
|
||||||
"tpu_v3_2x2_shardy",
|
"tpu_v3_2x2_shardy",
|
||||||
],
|
],
|
||||||
shard_count = {
|
shard_count = {
|
||||||
|
@ -35,7 +35,7 @@ jax_multiplatform_test(
|
|||||||
enable_backends = [],
|
enable_backends = [],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_h100",
|
"gpu_h100",
|
||||||
"gpu_h100_2gpu",
|
"gpu_h100x2",
|
||||||
],
|
],
|
||||||
shard_count = 8,
|
shard_count = 8,
|
||||||
tags = ["multiaccelerator"],
|
tags = ["multiaccelerator"],
|
||||||
|
@ -77,7 +77,7 @@ jax_multiplatform_test(
|
|||||||
],
|
],
|
||||||
disable_configs = [
|
disable_configs = [
|
||||||
"gpu_v100",
|
"gpu_v100",
|
||||||
"gpu_x32",
|
"gpu_v100_x32",
|
||||||
"gpu_a100",
|
"gpu_a100",
|
||||||
"gpu_p100",
|
"gpu_p100",
|
||||||
"gpu_p100_x32",
|
"gpu_p100_x32",
|
||||||
@ -97,7 +97,7 @@ jax_multiplatform_test(
|
|||||||
],
|
],
|
||||||
disable_configs = [
|
disable_configs = [
|
||||||
"gpu_v100",
|
"gpu_v100",
|
||||||
"gpu_x32",
|
"gpu_v100_x32",
|
||||||
"gpu_p100",
|
"gpu_p100",
|
||||||
"gpu_p100_x32",
|
"gpu_p100_x32",
|
||||||
],
|
],
|
||||||
@ -222,11 +222,11 @@ jax_multiplatform_test(
|
|||||||
name = "pallas_shape_poly_test",
|
name = "pallas_shape_poly_test",
|
||||||
srcs = ["pallas_shape_poly_test.py"],
|
srcs = ["pallas_shape_poly_test.py"],
|
||||||
disable_configs = [
|
disable_configs = [
|
||||||
"gpu_x32",
|
|
||||||
"gpu_h100",
|
"gpu_h100",
|
||||||
"gpu_p100",
|
"gpu_p100",
|
||||||
"gpu_p100_x32",
|
"gpu_p100_x32",
|
||||||
"gpu_pjrt_c_api",
|
"gpu_v100_x32",
|
||||||
|
"gpu_p100_pjrt_c_api",
|
||||||
],
|
],
|
||||||
enable_configs = [
|
enable_configs = [
|
||||||
"gpu_a100_x32",
|
"gpu_a100_x32",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user