From 71360edf90159edcacef1191a967f93592778737 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 9 Nov 2022 20:31:32 -0800 Subject: [PATCH] Bump the shard count for TPU to avoid timeouts PiperOrigin-RevId: 487421018 --- tests/BUILD | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 0023204ed..5c1b4dd7e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -384,7 +384,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 10, + "tpu": 40, "iree": 10, }, ) @@ -452,7 +452,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 20, + "tpu": 40, "iree": 40, }, deps = [":lax_test_lib"], @@ -661,7 +661,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 40, + "tpu": 50, }, ) @@ -688,7 +688,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 40, + "tpu": 50, "iree": 10, }, deps = [