Bump the shard count for TPU to avoid timeouts

PiperOrigin-RevId: 487421018
This commit is contained in:
Yash Katariya 2022-11-09 20:31:32 -08:00 committed by jax authors
parent e42e52d4aa
commit 71360edf90

View File

@ -384,7 +384,7 @@ jax_test(
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 10,
"tpu": 40,
"iree": 10,
},
)
@ -452,7 +452,7 @@ jax_test(
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
"tpu": 40,
"iree": 40,
},
deps = [":lax_test_lib"],
@ -661,7 +661,7 @@ jax_test(
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"tpu": 50,
},
)
@ -688,7 +688,7 @@ jax_test(
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"tpu": 50,
"iree": 10,
},
deps = [