Bump the shard count for TPU to avoid timeouts

PiperOrigin-RevId: 487421018
This commit is contained in:
Yash Katariya 2022-11-09 20:31:32 -08:00 committed by jax authors
parent e42e52d4aa
commit 71360edf90

View File

@ -384,7 +384,7 @@ jax_test(
shard_count = { shard_count = {
"cpu": 40, "cpu": 40,
"gpu": 40, "gpu": 40,
"tpu": 10, "tpu": 40,
"iree": 10, "iree": 10,
}, },
) )
@ -452,7 +452,7 @@ jax_test(
shard_count = { shard_count = {
"cpu": 40, "cpu": 40,
"gpu": 40, "gpu": 40,
"tpu": 20, "tpu": 40,
"iree": 40, "iree": 40,
}, },
deps = [":lax_test_lib"], deps = [":lax_test_lib"],
@ -661,7 +661,7 @@ jax_test(
shard_count = { shard_count = {
"cpu": 40, "cpu": 40,
"gpu": 40, "gpu": 40,
"tpu": 40, "tpu": 50,
}, },
) )
@ -688,7 +688,7 @@ jax_test(
shard_count = { shard_count = {
"cpu": 40, "cpu": 40,
"gpu": 40, "gpu": 40,
"tpu": 40, "tpu": 50,
"iree": 10, "iree": 10,
}, },
deps = [ deps = [