mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration. * Rename "_2gpu" test variants to "x2" variants, since this is more succinct. This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run. PiperOrigin-RevId: 715468944
This commit is contained in:
parent
f1b894d14a
commit
f122f17b27
18
tests/BUILD
18
tests/BUILD
@ -71,7 +71,7 @@ jax_multiplatform_test(
|
||||
"gpu",
|
||||
],
|
||||
enable_configs = [
|
||||
"gpu_2gpu",
|
||||
"gpu_p100x2",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = py_deps("tensorflow_core"),
|
||||
@ -226,12 +226,12 @@ jax_multiplatform_test(
|
||||
srcs = ["memories_test.py"],
|
||||
enable_configs = [
|
||||
"cpu",
|
||||
"gpu_2gpu",
|
||||
"gpu_p100x2",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
"tpu_v5p_2x2",
|
||||
"tpu_v5e_4x2",
|
||||
"gpu_2gpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
"tpu_v5e_4x2_shardy",
|
||||
],
|
||||
shard_count = {
|
||||
@ -250,10 +250,10 @@ jax_multiplatform_test(
|
||||
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
||||
},
|
||||
enable_configs = [
|
||||
"gpu_2gpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
"tpu_v3_2x2_shardy",
|
||||
"tpu_v3_2x2",
|
||||
"gpu_2gpu",
|
||||
"gpu_p100x2",
|
||||
],
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
@ -316,7 +316,7 @@ jax_multiplatform_test(
|
||||
srcs = ["mock_gpu_test.py"],
|
||||
enable_backends = ["gpu"],
|
||||
enable_configs = [
|
||||
"gpu_2gpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
],
|
||||
tags = [
|
||||
"config-cuda-only",
|
||||
@ -701,7 +701,7 @@ jax_multiplatform_test(
|
||||
srcs = ["multibackend_test.py"],
|
||||
enable_configs = [
|
||||
"tpu_v3_2x2",
|
||||
"gpu_2gpu",
|
||||
"gpu_p100x2",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1300,7 +1300,7 @@ jax_multiplatform_test(
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
"tpu_v3_2x2_shardy",
|
||||
"gpu_2gpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
@ -1364,7 +1364,7 @@ jax_multiplatform_test(
|
||||
name = "shard_map_test",
|
||||
srcs = ["shard_map_test.py"],
|
||||
enable_configs = [
|
||||
"gpu_2gpu_shardy",
|
||||
"gpu_p100x2_shardy",
|
||||
"tpu_v3_2x2_shardy",
|
||||
],
|
||||
shard_count = {
|
||||
|
@ -35,7 +35,7 @@ jax_multiplatform_test(
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_h100",
|
||||
"gpu_h100_2gpu",
|
||||
"gpu_h100x2",
|
||||
],
|
||||
shard_count = 8,
|
||||
tags = ["multiaccelerator"],
|
||||
|
@ -77,7 +77,7 @@ jax_multiplatform_test(
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_v100_x32",
|
||||
"gpu_a100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
@ -97,7 +97,7 @@ jax_multiplatform_test(
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_v100_x32",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
],
|
||||
@ -222,11 +222,11 @@ jax_multiplatform_test(
|
||||
name = "pallas_shape_poly_test",
|
||||
srcs = ["pallas_shape_poly_test.py"],
|
||||
disable_configs = [
|
||||
"gpu_x32",
|
||||
"gpu_h100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_pjrt_c_api",
|
||||
"gpu_v100_x32",
|
||||
"gpu_p100_pjrt_c_api",
|
||||
],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
|
Loading…
x
Reference in New Issue
Block a user