Increase bazel sharding of GPU tests.

Reduces the maximum time for some test shards to avoid flaky timeouts.
This commit is contained in:
Peter Hawkins 2022-07-11 13:30:44 +00:00
parent b666f665ec
commit 64e0b5d801

View File

@ -35,7 +35,7 @@ jax_generate_backend_suites()
jax_test(
name = "api_test",
srcs = ["api_test.py"],
shard_count = 5,
shard_count = 10,
)
jax_test(
@ -53,6 +53,9 @@ jax_test(
jax_test(
name = "batching_test",
srcs = ["batching_test.py"],
shard_count = {
"gpu": 5,
},
)
jax_test(
@ -66,6 +69,7 @@ jax_test(
srcs = ["core_test.py"],
shard_count = {
"cpu": 5,
"gpu": 10,
},
)
@ -157,6 +161,7 @@ jax_test(
srcs = ["xmap_test.py"],
shard_count = {
"cpu": 10,
"gpu": 4,
"tpu": 4,
},
tags = ["multiaccelerator"],
@ -210,7 +215,7 @@ jax_test(
srcs = ["image_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
@ -276,6 +281,7 @@ jax_test(
srcs = ["jet_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
},
deps = [
"//jax:jet",
@ -288,7 +294,7 @@ jax_test(
srcs = ["lax_control_flow_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
@ -361,7 +367,7 @@ jax_test(
},
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 40,
"tpu": 10,
"iree": 10,
},
@ -446,7 +452,7 @@ jax_test(
},
shard_count = {
"cpu": 20,
"gpu": 20,
"gpu": 40,
"tpu": 10,
"iree": 20,
},
@ -504,7 +510,7 @@ jax_test(
srcs = ["pmap_test.py"],
shard_count = {
"cpu": 5,
"gpu": 5,
"gpu": 10,
"tpu": 5,
},
tags = ["multiaccelerator"],
@ -579,7 +585,7 @@ jax_test(
main = "random_test.py",
shard_count = {
"cpu": 30,
"gpu": 20,
"gpu": 40,
"tpu": 20,
"iree": 20,
},
@ -618,6 +624,7 @@ jax_test(
], # Test times out under asan/tsan.
},
shard_count = {
"gpu": 10,
"tpu": 5,
},
)
@ -627,7 +634,7 @@ jax_test(
srcs = ["scipy_stats_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"gpu": 20,
"tpu": 10,
"iree": 10,
},
@ -655,7 +662,7 @@ jax_test(
},
shard_count = {
"cpu": 10,
"gpu": 20,
"gpu": 40,
"tpu": 10,
"iree": 10,
},
@ -668,6 +675,9 @@ jax_test(
name = "sparsify_test",
srcs = ["sparsify_test.py"],
args = ["--jax_bcoo_cusparse_lowering=true"],
shard_count = {
"gpu": 20,
},
deps = [
"//jax:experimental_sparse",
],
@ -787,6 +797,9 @@ jax_test(
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",
shard_count = {
"gpu": 5,
},
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",