diff --git a/tests/BUILD b/tests/BUILD index 0023204ed..5c1b4dd7e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -384,7 +384,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 10, + "tpu": 40, "iree": 10, }, ) @@ -452,7 +452,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 20, + "tpu": 40, "iree": 40, }, deps = [":lax_test_lib"], @@ -661,7 +661,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 40, + "tpu": 50, }, ) @@ -688,7 +688,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 40, + "tpu": 50, "iree": 10, }, deps = [