diff --git a/tests/BUILD b/tests/BUILD index 33370f1a1..dd80984b9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -935,7 +935,7 @@ jax_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, "tpu": 10, },