diff --git a/tests/BUILD b/tests/BUILD index a8c2f83c6..fc5c13207 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -587,7 +587,7 @@ jax_test( shard_count = { "cpu": 30, "gpu": 40, - "tpu": 20, + "tpu": 40, "iree": 20, }, ) @@ -664,7 +664,7 @@ jax_test( shard_count = { "cpu": 10, "gpu": 40, - "tpu": 10, + "tpu": 40, "iree": 10, }, deps = [