Add jax_array coverage to debug_nans_test

PiperOrigin-RevId: 478079509
This commit is contained in:
Yash Katariya 2022-09-30 14:20:57 -07:00 committed by jax authors
parent ec41de2c9b
commit fb8558cfdd
2 changed files with 2 additions and 6 deletions

View File

@ -81,13 +81,9 @@ jax_test(
srcs = ["custom_object_test.py"],
)
py_test(
jax_test(
name = "debug_nans_test",
srcs = ["debug_nans_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(

View File

@ -150,7 +150,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")
p = jax.experimental.PartitionSpec('x')
p = pjit.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x,
in_axis_resources=p,
out_axis_resources=p)