Increase sharding to avoid timeouts

PiperOrigin-RevId: 626008096
This commit is contained in:
Adam Paszke 2024-04-18 06:03:40 -07:00 committed by jax authors
parent 8e3f5b1018
commit fa66e731e6

View File

@ -667,6 +667,7 @@ jax_test(
name = "nn_test",
srcs = ["nn_test.py"],
shard_count = {
"cpu": 10,
"tpu": 10,
"gpu": 10,
},