Shard nn_test on GPU to avoid timeouts

PiperOrigin-RevId: 606224790
This commit is contained in:
Adam Paszke 2024-02-12 05:48:47 -08:00 committed by jax authors
parent 7dd887dc84
commit 1b2227283b

View File

@ -675,6 +675,7 @@ jax_test(
srcs = ["nn_test.py"],
shard_count = {
"tpu": 10,
"gpu": 10,
},
)