mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
This commit is contained in:
parent
aa3254d723
commit
19dbff5326
51
tests/BUILD
51
tests/BUILD
@ -34,6 +34,7 @@ jax_generate_backend_suites()
|
||||
jax_multiplatform_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_test.py"],
|
||||
enable_configs = ["tpu_v3_2x2"],
|
||||
shard_count = 10,
|
||||
)
|
||||
|
||||
@ -70,6 +71,9 @@ jax_multiplatform_test(
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_configs = [
|
||||
"gpu_2gpu",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = py_deps("tensorflow_core"),
|
||||
)
|
||||
@ -214,6 +218,14 @@ jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "memories_test",
|
||||
srcs = ["memories_test.py"],
|
||||
enable_configs = [
|
||||
"cpu",
|
||||
"gpu_2gpu",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
"tpu_v5p_2x2",
|
||||
"tpu_v5e_4x2",
|
||||
],
|
||||
shard_count = {
|
||||
"tpu": 5,
|
||||
},
|
||||
@ -234,6 +246,8 @@ jax_multiplatform_test(
|
||||
"gpu_2gpu_shardy",
|
||||
"tpu_v3_2x2_shardy",
|
||||
"tpu_v4_2x2_shardy",
|
||||
"tpu_v3_2x2",
|
||||
"gpu_2gpu",
|
||||
],
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
@ -258,6 +272,11 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "shard_alike_test",
|
||||
srcs = ["shard_alike_test.py"],
|
||||
enable_configs = [
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v5e_4x2",
|
||||
"tpu_v4_2x2",
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
@ -298,6 +317,9 @@ jax_multiplatform_test(
|
||||
backend_tags = {
|
||||
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
|
||||
},
|
||||
enable_configs = [
|
||||
"tpu_v3_2x2",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
@ -644,6 +666,10 @@ jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "multibackend_test",
|
||||
srcs = ["multibackend_test.py"],
|
||||
enable_configs = [
|
||||
"tpu_v3_2x2",
|
||||
"gpu_2gpu",
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
@ -693,6 +719,10 @@ jax_multiplatform_test(
|
||||
"requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit.
|
||||
],
|
||||
},
|
||||
enable_configs = [
|
||||
"gpu_v100",
|
||||
"tpu_v3_2x2",
|
||||
],
|
||||
shard_count = {
|
||||
"cpu": 30,
|
||||
"gpu": 30,
|
||||
@ -1030,6 +1060,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "checkify_test",
|
||||
srcs = ["checkify_test.py"],
|
||||
enable_configs = ["tpu_v3_2x2"],
|
||||
shard_count = {
|
||||
"gpu": 2,
|
||||
"tpu": 4,
|
||||
@ -1187,8 +1218,11 @@ jax_multiplatform_test(
|
||||
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
||||
},
|
||||
enable_configs = [
|
||||
"gpu_h100",
|
||||
"cpu",
|
||||
"gpu_h100",
|
||||
"tpu_v2_1x1",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
)
|
||||
@ -1197,8 +1231,11 @@ jax_multiplatform_test(
|
||||
name = "debugging_primitives_test",
|
||||
srcs = ["debugging_primitives_test.py"],
|
||||
enable_configs = [
|
||||
"gpu_h100",
|
||||
"cpu",
|
||||
"gpu_h100",
|
||||
"tpu_v2_1x1",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1208,6 +1245,11 @@ jax_multiplatform_test(
|
||||
backend_tags = {
|
||||
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
||||
},
|
||||
enable_configs = [
|
||||
"tpu_v2_1x1",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
@ -1218,8 +1260,11 @@ jax_multiplatform_test(
|
||||
name = "debugger_test",
|
||||
srcs = ["debugger_test.py"],
|
||||
enable_configs = [
|
||||
"gpu_h100",
|
||||
"cpu",
|
||||
"gpu_h100",
|
||||
"tpu_v2_1x1",
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v4_2x2",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -244,6 +244,9 @@ jax_multiplatform_test(
|
||||
"tpu_all_gather_test.py",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
enable_configs = [
|
||||
"tpu_v5e_4x2",
|
||||
],
|
||||
deps = [
|
||||
"//jax:pallas_tpu_ops",
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
@ -277,6 +280,10 @@ jax_multiplatform_test(
|
||||
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
|
||||
args = ["--logtostderr"],
|
||||
enable_backends = ["tpu"],
|
||||
enable_configs = [
|
||||
"tpu_v5e",
|
||||
"tpu_v5p_1x1",
|
||||
],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
"//jax:pallas_tpu",
|
||||
@ -305,6 +312,12 @@ jax_multiplatform_test(
|
||||
name = "tpu_pallas_distributed_test",
|
||||
srcs = ["tpu_pallas_distributed_test.py"],
|
||||
enable_backends = ["tpu"],
|
||||
enable_configs = [
|
||||
"tpu_v5e_4x2",
|
||||
"tpu_v5p_2x2",
|
||||
"tpu_v4_2x2",
|
||||
"tpu_v3_2x2",
|
||||
],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
"//jax:pallas_tpu",
|
||||
@ -316,6 +329,10 @@ jax_multiplatform_test(
|
||||
name = "tpu_pallas_pipeline_test",
|
||||
srcs = ["tpu_pallas_pipeline_test.py"],
|
||||
enable_backends = ["tpu"],
|
||||
enable_configs = [
|
||||
"tpu_v5e_4x2",
|
||||
"tpu_v5p_1x1",
|
||||
],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -333,7 +350,9 @@ jax_multiplatform_test(
|
||||
name = "tpu_pallas_async_test",
|
||||
srcs = ["tpu_pallas_async_test.py"],
|
||||
enable_backends = ["tpu"],
|
||||
tags = [
|
||||
enable_configs = [
|
||||
"tpu_v5e_4x2",
|
||||
"tpu_v5p_1x1",
|
||||
],
|
||||
deps = [
|
||||
"//jax:pallas_tpu",
|
||||
|
Loading…
x
Reference in New Issue
Block a user