Move additional CI enabled/disabled configurations into jax BUILD files.

PiperOrigin-RevId: 684457403
This commit is contained in:
Peter Hawkins 2024-10-10 08:40:51 -07:00 committed by jax authors
parent aa3254d723
commit 19dbff5326
2 changed files with 68 additions and 4 deletions

View File

@ -34,6 +34,7 @@ jax_generate_backend_suites()
jax_multiplatform_test(
name = "api_test",
srcs = ["api_test.py"],
enable_configs = ["tpu_v3_2x2"],
shard_count = 10,
)
@ -70,6 +71,9 @@ jax_multiplatform_test(
"cpu",
"gpu",
],
enable_configs = [
"gpu_2gpu",
],
tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"),
)
@ -214,6 +218,14 @@ jax_py_test(
jax_multiplatform_test(
name = "memories_test",
srcs = ["memories_test.py"],
enable_configs = [
"cpu",
"gpu_2gpu",
"tpu_v3_2x2",
"tpu_v4_2x2",
"tpu_v5p_2x2",
"tpu_v5e_4x2",
],
shard_count = {
"tpu": 5,
},
@ -234,6 +246,8 @@ jax_multiplatform_test(
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
"tpu_v3_2x2",
"gpu_2gpu",
],
shard_count = {
"cpu": 5,
@ -258,6 +272,11 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "shard_alike_test",
srcs = ["shard_alike_test.py"],
enable_configs = [
"tpu_v3_2x2",
"tpu_v5e_4x2",
"tpu_v4_2x2",
],
deps = [
"//jax:experimental",
],
@ -298,6 +317,9 @@ jax_multiplatform_test(
backend_tags = {
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
},
enable_configs = [
"tpu_v3_2x2",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
@ -644,6 +666,10 @@ jax_py_test(
jax_multiplatform_test(
name = "multibackend_test",
srcs = ["multibackend_test.py"],
enable_configs = [
"tpu_v3_2x2",
"gpu_2gpu",
],
)
jax_multiplatform_test(
@ -693,6 +719,10 @@ jax_multiplatform_test(
"requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit.
],
},
enable_configs = [
"gpu_v100",
"tpu_v3_2x2",
],
shard_count = {
"cpu": 30,
"gpu": 30,
@ -1030,6 +1060,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "checkify_test",
srcs = ["checkify_test.py"],
enable_configs = ["tpu_v3_2x2"],
shard_count = {
"gpu": 2,
"tpu": 4,
@ -1187,8 +1218,11 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
tags = ["multiaccelerator"],
)
@ -1197,8 +1231,11 @@ jax_multiplatform_test(
name = "debugging_primitives_test",
srcs = ["debugging_primitives_test.py"],
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
)
@ -1208,6 +1245,11 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
@ -1218,8 +1260,11 @@ jax_multiplatform_test(
name = "debugger_test",
srcs = ["debugger_test.py"],
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
)

View File

@ -244,6 +244,9 @@ jax_multiplatform_test(
"tpu_all_gather_test.py",
],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
],
deps = [
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
@ -277,6 +280,10 @@ jax_multiplatform_test(
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
args = ["--logtostderr"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e",
"tpu_v5p_1x1",
],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
@ -305,6 +312,12 @@ jax_multiplatform_test(
name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_2x2",
"tpu_v4_2x2",
"tpu_v3_2x2",
],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
@ -316,6 +329,10 @@ jax_multiplatform_test(
name = "tpu_pallas_pipeline_test",
srcs = ["tpu_pallas_pipeline_test.py"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_1x1",
],
shard_count = 5,
tags = [
"noasan", # Times out.
@ -333,7 +350,9 @@ jax_multiplatform_test(
name = "tpu_pallas_async_test",
srcs = ["tpu_pallas_async_test.py"],
enable_backends = ["tpu"],
tags = [
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_1x1",
],
deps = [
"//jax:pallas_tpu",