mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
This commit is contained in:
parent
b666f665ec
commit
64e0b5d801
31
tests/BUILD
31
tests/BUILD
@ -35,7 +35,7 @@ jax_generate_backend_suites()
|
||||
jax_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_test.py"],
|
||||
shard_count = 5,
|
||||
shard_count = 10,
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -53,6 +53,9 @@ jax_test(
|
||||
jax_test(
|
||||
name = "batching_test",
|
||||
srcs = ["batching_test.py"],
|
||||
shard_count = {
|
||||
"gpu": 5,
|
||||
},
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -66,6 +69,7 @@ jax_test(
|
||||
srcs = ["core_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
"gpu": 10,
|
||||
},
|
||||
)
|
||||
|
||||
@ -157,6 +161,7 @@ jax_test(
|
||||
srcs = ["xmap_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 4,
|
||||
"tpu": 4,
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
@ -210,7 +215,7 @@ jax_test(
|
||||
srcs = ["image_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
"gpu": 20,
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
@ -276,6 +281,7 @@ jax_test(
|
||||
srcs = ["jet_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
},
|
||||
deps = [
|
||||
"//jax:jet",
|
||||
@ -288,7 +294,7 @@ jax_test(
|
||||
srcs = ["lax_control_flow_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
"gpu": 20,
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
@ -361,7 +367,7 @@ jax_test(
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
"gpu": 40,
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
@ -446,7 +452,7 @@ jax_test(
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 20,
|
||||
"gpu": 20,
|
||||
"gpu": 40,
|
||||
"tpu": 10,
|
||||
"iree": 20,
|
||||
},
|
||||
@ -504,7 +510,7 @@ jax_test(
|
||||
srcs = ["pmap_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
"gpu": 5,
|
||||
"gpu": 10,
|
||||
"tpu": 5,
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
@ -579,7 +585,7 @@ jax_test(
|
||||
main = "random_test.py",
|
||||
shard_count = {
|
||||
"cpu": 30,
|
||||
"gpu": 20,
|
||||
"gpu": 40,
|
||||
"tpu": 20,
|
||||
"iree": 20,
|
||||
},
|
||||
@ -618,6 +624,7 @@ jax_test(
|
||||
], # Test times out under asan/tsan.
|
||||
},
|
||||
shard_count = {
|
||||
"gpu": 10,
|
||||
"tpu": 5,
|
||||
},
|
||||
)
|
||||
@ -627,7 +634,7 @@ jax_test(
|
||||
srcs = ["scipy_stats_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
"gpu": 20,
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
@ -655,7 +662,7 @@ jax_test(
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
"gpu": 20,
|
||||
"gpu": 40,
|
||||
"tpu": 10,
|
||||
"iree": 10,
|
||||
},
|
||||
@ -668,6 +675,9 @@ jax_test(
|
||||
name = "sparsify_test",
|
||||
srcs = ["sparsify_test.py"],
|
||||
args = ["--jax_bcoo_cusparse_lowering=true"],
|
||||
shard_count = {
|
||||
"gpu": 20,
|
||||
},
|
||||
deps = [
|
||||
"//jax:experimental_sparse",
|
||||
],
|
||||
@ -787,6 +797,9 @@ jax_test(
|
||||
"tpu", # On TPU we always use outfeed
|
||||
],
|
||||
main = "host_callback_test.py",
|
||||
shard_count = {
|
||||
"gpu": 5,
|
||||
},
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax:experimental_host_callback",
|
||||
|
Loading…
x
Reference in New Issue
Block a user