Increase sharding of host_callback_test on TPU to fix CI flakiness.

PiperOrigin-RevId: 533451822
This commit is contained in:
Peter Hawkins 2023-05-19 07:44:18 -07:00 committed by jax authors
parent acc527d011
commit 1d20d2f301

View File

@ -914,6 +914,9 @@ jax_test(
name = "host_callback_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=true"],
shard_count = {
"tpu": 5,
},
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",