diff --git a/tests/BUILD b/tests/BUILD index e75480316..9e7070e3c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -419,7 +419,7 @@ jax_test( shard_count = { "cpu": 40, "gpu": 40, - "tpu": 30, + "tpu": 40, "iree": 40, }, deps = [