diff --git a/tests/BUILD b/tests/BUILD index b2c34402b..ff686aaa1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -148,7 +148,7 @@ jax_test( shard_count = { "cpu": 10, "gpu": 10, - "tpu": 10, + "tpu": 20, }, ) @@ -540,6 +540,9 @@ jax_test( jax_test( name = "nn_test", srcs = ["nn_test.py"], + shard_count = { + "tpu": 10, + }, ) jax_test(