diff --git a/tests/BUILD b/tests/BUILD index 11e382b1a..46345b647 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1072,7 +1072,7 @@ jax_multiplatform_test( srcs = ["checkify_test.py"], shard_count = { "gpu": 2, - "tpu": 2, + "tpu": 4, }, )