From fb8558cfdd10b40ceafa1cde7b29777e03b8b29e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 30 Sep 2022 14:20:57 -0700 Subject: [PATCH] Add jax_array coverage to debug_nans_test PiperOrigin-RevId: 478079509 --- tests/BUILD | 6 +----- tests/debug_nans_test.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 2f30937bd..4dd0404fd 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -81,13 +81,9 @@ jax_test( srcs = ["custom_object_test.py"], ) -py_test( +jax_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], - deps = [ - "//jax", - "//jax:test_util", - ], ) py_test( diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 31fac66a5..8240f7c08 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -150,7 +150,7 @@ class DebugNaNsTest(jtu.JaxTestCase): if jax.device_count() < 2: raise SkipTest("test requires >=2 devices") - p = jax.experimental.PartitionSpec('x') + p = pjit.PartitionSpec('x') f = pjit.pjit(lambda x: 0. / x, in_axis_resources=p, out_axis_resources=p)